Add training metadata to output LoRA model

This commit is contained in:
space-nuko
2023-01-10 02:49:52 -08:00
parent a84ca297bd
commit 2e4ce0fdff
4 changed files with 72 additions and 5 deletions

View File

@@ -135,7 +135,7 @@ def svd(args):
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype)
lora_network_o.save_weights(args.save_to, save_dtype, {})
print(f"LoRA weights are saved to: {args.save_to}")

View File

@@ -6,6 +6,8 @@
import math
import os
import torch
import zipfile
import json
class LoRAModule(torch.nn.Module):
@@ -61,6 +63,7 @@ class LoRANetwork(torch.nn.Module):
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.metadata = {}
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -91,11 +94,17 @@ class LoRANetwork(torch.nn.Module):
names.add(lora.lora_name)
def load_weights(self, file):
self.metadata = {}
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file
self.weights_sd = load_file(file)
self.metadata = self.weights_sd.metadata()
else:
self.weights_sd = torch.load(file, map_location='cpu')
with zipfile.ZipFile(file, "w") as zipf:
if "sd_scripts_metadata.json" in zipf.namelist():
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
self.metadata = json.load(jsfile)
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
if self.weights_sd:
@@ -174,7 +183,8 @@ class LoRANetwork(torch.nn.Module):
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype):
def save_weights(self, file, dtype, metadata):
self.metadata = metadata
state_dict = self.state_dict()
if dtype is not None:
@@ -185,6 +195,8 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
save_file(state_dict, file)
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
with zipfile.ZipFile(file, "w") as zipf:
zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))