mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add training metadata to output LoRA model
This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
import zipfile
|
||||
import json
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
@@ -61,6 +63,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
self.metadata = {}
|
||||
|
||||
# create module instances
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||
@@ -91,11 +94,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
names.add(lora.lora_name)
|
||||
|
||||
def load_weights(self, file):
|
||||
self.metadata = {}
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import load_file
|
||||
self.weights_sd = load_file(file)
|
||||
self.metadata = self.weights_sd.metadata()
|
||||
else:
|
||||
self.weights_sd = torch.load(file, map_location='cpu')
|
||||
with zipfile.ZipFile(file, "w") as zipf:
|
||||
if "sd_scripts_metadata.json" in zipf.namelist():
|
||||
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
|
||||
self.metadata = json.load(jsfile)
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||
if self.weights_sd:
|
||||
@@ -174,7 +183,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype):
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
self.metadata = metadata
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
@@ -185,6 +195,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import save_file
|
||||
save_file(state_dict, file)
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
with zipfile.ZipFile(file, "w") as zipf:
|
||||
zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))
|
||||
|
||||
Reference in New Issue
Block a user