Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge branch 'main' into caption-frequency-metadata
@@ -12,6 +12,7 @@ import math
 import os
 import random
 import hashlib
+from io import BytesIO

 from tqdm import tqdm
 import torch
@@ -25,6 +26,7 @@ from PIL import Image
 import cv2
 from einops import rearrange
 from torch import einsum
+import safetensors.torch

 import library.model_util as model_util

@@ -86,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
         self.min_bucket_reso = None
         self.max_bucket_reso = None
         self.tag_frequency = {}
+        self.bucket_info = None

         self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

@@ -111,9 +114,14 @@ class BaseDataset(torch.utils.data.Dataset):

         self.image_data: dict[str, ImageInfo] = {}

+        self.replacements = {}
+
     def disable_token_padding(self):
         self.token_padding_disabled = True

+    def add_replacement(self, str_from, str_to):
+        self.replacements[str_from] = str_to
+
     def process_caption(self, caption):
         if self.shuffle_caption:
             tokens = caption.strip().split(",")
@@ -126,6 +134,17 @@ class BaseDataset(torch.utils.data.Dataset):
                 random.shuffle(tokens)
                 tokens = keep_tokens + tokens
             caption = ",".join(tokens).strip()
+
+        for str_from, str_to in self.replacements.items():
+            if str_from == "":
+                # replace all
+                if type(str_to) == list:
+                    caption = random.choice(str_to)
+                else:
+                    caption = str_to
+            else:
+                caption = caption.replace(str_from, str_to)
+
         return caption

     def get_input_ids(self, caption):
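To make the new hooks concrete, here is a minimal stand-alone sketch of the replacement behaviour merged into process_caption; in the scripts the mapping lives on BaseDataset and is registered through add_replacement(), and the caption and mapping below are invented:

import random

# Stand-alone re-implementation of the loop added above (illustration only).
replacements = {"1girl": "1boy"}  # hypothetical mapping registered with add_replacement()

def apply_replacements(caption):
    for str_from, str_to in replacements.items():
        if str_from == "":
            # an empty key replaces the whole caption; a list means "pick one at random"
            caption = random.choice(str_to) if type(str_to) == list else str_to
        else:
            caption = caption.replace(str_from, str_to)
    return caption

print(apply_replacements("1girl, solo, smile"))  # -> "1boy, solo, smile"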
@@ -218,11 +237,17 @@ class BaseDataset(torch.utils.data.Dataset):
            self.buckets[bucket_index].append(image_info.image_key)

        if self.enable_bucket:
+            self.bucket_info = {"buckets": {}}
            print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
            for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
+                self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
                print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")

            img_ar_errors = np.array(img_ar_errors)
-            print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
+            mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+            self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
+            print(f"mean ar error (without repeats): {mean_img_ar_error}")

        # 参照用indexを作る (build the lookup index)
        self.buckets_indices: list(BucketBatchIndex) = []
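For reference, the bucket_info dict built above ends up with this shape; the numbers below are invented, and the metadata key at the end is only an assumed example of how it could be serialized when the training metadata is written:

import json

bucket_info = {
    "buckets": {
        0: {"resolution": (512, 512), "count": 120},
        1: {"resolution": (576, 448), "count": 34},
    },
    "mean_img_ar_error": 0.0123,
}

# Hypothetical serialization into saved training metadata (key name assumed).
metadata = {"ss_bucket_info": json.dumps(bucket_info)}
print(metadata["ss_bucket_info"])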
@@ -609,7 +634,7 @@ class FineTuningDataset(BaseDataset):
            else:
                # わりといい加減だがいい方法が思いつかん (pretty rough, but no better way comes to mind)
                abs_path = glob_images(train_data_dir, image_key)
-               assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
+               assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
                abs_path = abs_path[0]

            caption = img_md.get('caption')
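The message change matters because glob_images() returns a list: when nothing matches, the old f-string printed an empty list rather than naming the missing image. A tiny sketch of the difference (values invented):

abs_path = []            # what the glob returns when no file matches
image_key = "img_0001"

print(f"no image / 画像がありません: {abs_path}")   # old: ends with "[]"
print(f"no image / 画像がありません: {image_key}")  # new: names the missing key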
@@ -716,15 +741,17 @@ class FineTuningDataset(BaseDataset):
        return npz_file_norm, npz_file_flip


-def debug_dataset(train_dataset):
+def debug_dataset(train_dataset, show_input_ids=False):
    print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
    print("Escape for exit. / Escキーで中断、終了します")
    k = 0
    for example in train_dataset:
        if example['latents'] is not None:
            print("sample has latents from npz file")
-        for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
+        for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
            print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
+            if show_input_ids:
+                print(f"input ids: {iid}")
            if example['images'] is not None:
                im = example['images'][j]
                im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
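A sketch of how the new flag might be exercised from a training script, assuming this file is library/train_util.py and train_dataset is the dataset instance the script has already built:

from library import train_util

train_util.debug_dataset(train_dataset)                       # previous behaviour: keys, captions, loss weights
train_util.debug_dataset(train_dataset, show_input_ids=True)  # additionally prints the tokenizer input ids per sample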
@@ -800,6 +827,49 @@ def calculate_sha256(filename):
    return hash_sha256.hexdigest()


+def precalculate_safetensors_hashes(tensors, metadata):
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+    bytes = safetensors.torch.save(tensors, metadata)
+    b = BytesIO(bytes)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    return model_hash, legacy_hash
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
+
+
 # flash attention forwards and backwards

 # https://arxiv.org/abs/2205.14135
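The new hash deliberately skips the .safetensors JSON header: the first 8 bytes of the file store the header length n as a little-endian integer, so hashing starts at byte n + 8, where the raw tensor data begins. Below is a small self-contained check of that layout and of the helpers above, assuming it runs in the same module as those functions; the tensor contents and metadata are invented, and note that the legacy hash reads from offset 0x100000, so for a tiny buffer like this it hashes nothing:

import torch
from io import BytesIO
import safetensors.torch

tensors = {"weight": torch.zeros(4, 4)}                 # arbitrary example tensor
metadata = {"ss_example": "1", "user_note": "dropped"}  # only ss_* keys are kept for hashing

model_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
print(model_hash, legacy_hash)

# Layout check: bytes 0-7 hold the header length n; tensor data starts at n + 8.
buf = BytesIO(safetensors.torch.save(tensors, metadata))
n = int.from_bytes(buf.read(8), "little")
print("header bytes:", n, "| tensor data starts at offset:", n + 8)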
@@ -1067,6 +1137,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
    parser.add_argument("--save_every_n_epochs", type=int, default=None,
                        help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
+   parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
+                       help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
    parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
    parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
                        help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
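The help text implies the ratio is turned into an every-N-epochs interval so that at least that many checkpoints are written over the whole run. A rough sketch of that conversion, as an illustration of the intent rather than the scripts' exact logic (the helper name is invented):

import math

def epochs_between_saves(num_train_epochs, save_n_epoch_ratio):
    # Hypothetical helper: save often enough that at least `save_n_epoch_ratio`
    # checkpoints are produced over `num_train_epochs` epochs.
    return max(1, math.floor(num_train_epochs / save_n_epoch_ratio))

print(epochs_between_saves(30, 5))  # -> 6: saving every 6 epochs yields at least 5 checkpoints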