mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add training metadata to output LoRA model
This commit is contained in:
@@ -747,6 +747,20 @@ def exists(val):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
return val if exists(val) else d
|
return val if exists(val) else d
|
||||||
|
|
||||||
|
|
||||||
|
def model_hash(filename):
|
||||||
|
try:
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
import hashlib
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
file.seek(0x100000)
|
||||||
|
m.update(file.read(0x10000))
|
||||||
|
return m.hexdigest()[0:8]
|
||||||
|
except FileNotFoundError:
|
||||||
|
return 'NOFILE'
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ def svd(args):
|
|||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
os.makedirs(dir_name, exist_ok=True)
|
os.makedirs(dir_name, exist_ok=True)
|
||||||
|
|
||||||
lora_network_o.save_weights(args.save_to, save_dtype)
|
lora_network_o.save_weights(args.save_to, save_dtype, {})
|
||||||
print(f"LoRA weights are saved to: {args.save_to}")
|
print(f"LoRA weights are saved to: {args.save_to}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import zipfile
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
@@ -61,6 +63,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||||
@@ -91,11 +94,17 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
|
self.metadata = {}
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
self.weights_sd = load_file(file)
|
self.weights_sd = load_file(file)
|
||||||
|
self.metadata = self.weights_sd.metadata()
|
||||||
else:
|
else:
|
||||||
self.weights_sd = torch.load(file, map_location='cpu')
|
self.weights_sd = torch.load(file, map_location='cpu')
|
||||||
|
with zipfile.ZipFile(file, "w") as zipf:
|
||||||
|
if "sd_scripts_metadata.json" in zipf.namelist():
|
||||||
|
with zipf.open("sd_scripts_metadata.json", "r") as jsfile:
|
||||||
|
self.metadata = json.load(jsfile)
|
||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
||||||
if self.weights_sd:
|
if self.weights_sd:
|
||||||
@@ -174,7 +183,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
def get_trainable_params(self):
|
def get_trainable_params(self):
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def save_weights(self, file, dtype):
|
def save_weights(self, file, dtype, metadata):
|
||||||
|
self.metadata = metadata
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
@@ -185,6 +195,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
save_file(state_dict, file)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
with zipfile.ZipFile(file, "w") as zipf:
|
||||||
|
zipf.writestr("sd_scripts_metadata.json", json.dumps(metadata))
|
||||||
|
|||||||
@@ -197,6 +197,47 @@ def train(args):
|
|||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"ss_learning_rate": args.learning_rate,
|
||||||
|
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||||
|
"ss_unet_lr": args.unet_lr,
|
||||||
|
"ss_num_train_images": train_dataset.num_train_images,
|
||||||
|
"ss_num_reg_images": train_dataset.num_reg_images,
|
||||||
|
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||||
|
"ss_num_epochs": num_train_epochs,
|
||||||
|
"ss_batch_size_per_device": args.train_batch_size,
|
||||||
|
"ss_total_batch_size": total_batch_size,
|
||||||
|
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
|
"ss_max_train_steps": args.max_train_steps,
|
||||||
|
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||||
|
"ss_lr_scheduler": args.lr_scheduler,
|
||||||
|
"ss_network_module": args.network_module,
|
||||||
|
"ss_network_dim": 4 if args.network_dim is None else args.network_dim,
|
||||||
|
"ss_full_fp16": bool(args.full_fp16),
|
||||||
|
"ss_v2": bool(args.v2),
|
||||||
|
"ss_resolution": args.resolution,
|
||||||
|
"ss_clip_skip": args.clip_skip,
|
||||||
|
"ss_max_token_length": args.max_token_length,
|
||||||
|
"ss_color_aug": bool(args.color_aug),
|
||||||
|
"ss_flip_aug": bool(args.flip_aug),
|
||||||
|
"ss_random_crop": bool(args.random_crop),
|
||||||
|
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||||
|
"ss_cache_latents": bool(args.cache_latents),
|
||||||
|
"ss_enable_bucket": bool(args.enable_bucket),
|
||||||
|
"ss_min_bucket_reso": args.min_bucket_reso,
|
||||||
|
"ss_max_bucket_reso": args.max_bucket_reso,
|
||||||
|
"ss_seed": args.seed
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.pretrained_model_name_or_path is not None:
|
||||||
|
sd_model_name = args.pretrained_model_name_or_path
|
||||||
|
if os.path.exists(sd_model_name):
|
||||||
|
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
||||||
|
sd_model_name = os.path.basename(sd_model_name)
|
||||||
|
metadata["ss_sd_model_name"] = sd_model_name
|
||||||
|
|
||||||
|
metadata = {k: str(v) for k, v in metadata.items()}
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
@@ -296,7 +337,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||||
@@ -330,7 +371,7 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype)
|
network.save_weights(ckpt_file, save_dtype, metadata)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user