Mirror of https://github.com/kohya-ss/sd-scripts.git

Autocast shouldn't be on dtype float32
@@ -14,7 +14,7 @@ def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lor
     weight = org_module.weight.data.to(device, dtype=torch.float32)
 
-    with torch.autocast(device.type, dtype=torch.float32):
+    with torch.autocast(device.type):
         # SVD decomposition
         V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
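
For context (not part of the commit): CUDA autocast only accepts float16 or bfloat16 as the target dtype, and depending on the PyTorch version, requesting float32 either raises an error or emits a warning and disables autocast, which is presumably why the explicit dtype argument is dropped here. A minimal sketch of the resulting pattern (not the repository's code; shapes and names are arbitrary):

import torch

# Minimal sketch, assuming a float32 weight on whatever device is available:
# the autocast region is opened with the device's default autocast dtype
# instead of forcing dtype=torch.float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight = torch.randn(64, 32, device=device, dtype=torch.float32)

with torch.autocast(device.type):
    # The input is float32 and autocast does not downcast decompositions such as
    # torch.linalg.svd to half precision, so the factorization still runs in float32.
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
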
@@ -58,7 +58,7 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
     weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
 
-    with torch.autocast(device.type, dtype=torch.float32):
+    with torch.autocast(device.type):
         # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
         Vr = V[:, : rank]
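
The hunk above only shows the start of the SVD. For readers unfamiliar with PiSSA, here is a self-contained sketch of what an SVD-based initialization of this kind does; it is a simplified, hypothetical helper (the name, the scale handling, and the device default are assumptions, not the repository's exact initialize_pissa). The top-rank singular components of the frozen weight seed lora_up/lora_down, and the residual replaces the base weight so the initial effective weight is unchanged.

import math
import torch

def pissa_init_sketch(
    org_module: torch.nn.Linear,
    lora_down: torch.nn.Linear,   # Linear(in_features, rank, bias=False)
    lora_up: torch.nn.Linear,     # Linear(rank, out_features, bias=False)
    scale: float,                 # assumed > 0 (e.g. alpha / rank)
    rank: int,
    device: torch.device = torch.device("cpu"),
) -> None:
    # Work on a float32 copy of the base weight, shape [out_features, in_features].
    weight = org_module.weight.data.clone().to(device, dtype=torch.float32)

    # Thin SVD: weight = V @ diag(S) @ Uh.
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
    Vr, Sr, Uhr = V[:, :rank], S[:rank], Uh[:rank, :]

    # Split sqrt(S) across the two factors and fold 1/sqrt(scale) into each side,
    # so that scale * (up @ down) equals the principal component Vr @ diag(Sr) @ Uhr.
    up = Vr * (Sr.sqrt() / math.sqrt(scale)).unsqueeze(0)
    down = (Sr.sqrt() / math.sqrt(scale)).unsqueeze(1) * Uhr

    lora_up.weight.data = up.to(org_module.weight.dtype)
    lora_down.weight.data = down.to(org_module.weight.dtype)

    # Subtract the principal component from the frozen weight so that
    # base + scale * up @ down reproduces the original weight at initialization.
    residual = weight - scale * (up @ down)
    org_module.weight.data = residual.to(org_module.weight.dtype).to(org_module.weight.device)
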
@@ -7,16 +7,13 @@
 # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
 
 import math
 import os
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Type, Union
 from diffusers import AutoencoderKL
 from transformers import CLIPTextModel
 import numpy as np
 import torch
 from torch import Tensor
 from tqdm import tqdm
 import re
 from library.utils import setup_logging
 from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
@@ -80,7 +77,7 @@ class LoRAModule(torch.nn.Module):
             )
             self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
 
-        with torch.autocast(org_module.weight.device.type), torch.no_grad():
+        with torch.autocast("cuda"), torch.no_grad():
             self.initialize_weights(org_module)
 
         # same as microsoft's
@@ -99,7 +96,7 @@ class LoRAModule(torch.nn.Module):
             self._org_lora_up = self.lora_up.weight.data.detach().clone()
             self._org_lora_down = self.lora_down.weight.data.detach().clone()
         elif self.initialize == "pissa":
-            initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
+            initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
             # Need to store the original weights so we can get a plain LoRA out
             self._org_lora_up = self.lora_up.weight.data.detach().clone()
             self._org_lora_down = self.lora_down.weight.data.detach().clone()
@@ -115,7 +112,7 @@ class LoRAModule(torch.nn.Module):
             self._org_lora_up = lora_up.weight.data.detach().clone()
             self._org_lora_down = lora_down.weight.data.detach().clone()
         elif self.initialize == "pissa":
-            initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
+            initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
             # Need to store the original weights so we can get a plain LoRA out
             self._org_lora_up = lora_up.weight.data.detach().clone()
             self._org_lora_down = lora_down.weight.data.detach().clone()
@@ -1090,7 +1087,7 @@ class LoRANetwork(torch.nn.Module):
         state_dict = self.state_dict()
 
         if self.initialize in ['pissa']:
-            loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+            loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
             def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
                 # Calculate ΔW = A'B' - AB
                 delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
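
The diff ends mid-function. As a rough illustration of where convert_pissa_to_standard_lora is heading, here is a hypothetical completion (the name suffix and the truncated-SVD choice are assumptions, not the repository's code): the delta accumulated during PiSSA training, ΔW = A'B' - AB, is factored back into an ordinary rank-r LoRA pair that can be applied on top of the original base weights. Any network-level scale multiplies both products equally, so it drops out of the conversion.

import torch
from torch import Tensor

def convert_pissa_to_standard_lora_sketch(
    trained_up: Tensor,    # A' : [out_features, rank]
    trained_down: Tensor,  # B' : [rank, in_features]
    orig_up: Tensor,       # A  : PiSSA-initialized up factor
    orig_down: Tensor,     # B  : PiSSA-initialized down factor
    rank: int,
) -> tuple[Tensor, Tensor]:
    # ΔW = A'B' - AB: what training actually changed relative to the PiSSA initialization.
    delta_w = (trained_up.float() @ trained_down.float()) - (orig_up.float() @ orig_down.float())

    # Truncated SVD gives the best rank-r approximation of ΔW. Note that ΔW can have
    # rank up to 2r (a difference of two rank-r products), so keeping r components is lossy.
    V, S, Uh = torch.linalg.svd(delta_w, full_matrices=False)
    new_up = V[:, :rank] * S[:rank].sqrt().unsqueeze(0)
    new_down = S[:rank].sqrt().unsqueeze(1) * Uh[:rank, :]
    return new_up.to(trained_up.dtype), new_down.to(trained_down.dtype)
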