Autocast shouldn't explicitly set dtype=float32

This commit is contained in:
rockerBOO
2025-03-25 19:05:03 -04:00
parent c5c07a40c5
commit 54d4de0e72
2 changed files with 7 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lor
weight = org_module.weight.data.to(device, dtype=torch.float32)
with torch.autocast(device.type, dtype=torch.float32):
with torch.autocast(device.type):
# SVD decomposition
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
@@ -58,7 +58,7 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
with torch.autocast(device.type, dtype=torch.float32):
with torch.autocast(device.type):
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, : rank]

View File

@@ -7,16 +7,13 @@
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Dict, List, Optional, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
import re
from library.utils import setup_logging
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
@@ -80,7 +77,7 @@ class LoRAModule(torch.nn.Module):
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
with torch.autocast(org_module.weight.device.type), torch.no_grad():
with torch.autocast("cuda"), torch.no_grad():
self.initialize_weights(org_module)
# same as microsoft's
@@ -99,7 +96,7 @@ class LoRAModule(torch.nn.Module):
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
@@ -115,7 +112,7 @@ class LoRAModule(torch.nn.Module):
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
@@ -1090,7 +1087,7 @@ class LoRANetwork(torch.nn.Module):
state_dict = self.state_dict()
if self.initialize in ['pissa']:
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
# Calculate ΔW = A'B' - AB
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)