mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'main' into textual_inversion
This commit is contained in:
@@ -11,6 +11,8 @@ import glob
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -24,6 +26,7 @@ from PIL import Image
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
import safetensors.torch
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
@@ -79,6 +82,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.debug_dataset = debug_dataset
|
||||
self.random_crop = random_crop
|
||||
self.token_padding_disabled = False
|
||||
self.dataset_dirs_info = {}
|
||||
self.reg_dataset_dirs_info = {}
|
||||
self.enable_bucket = False
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_info = None
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@@ -227,11 +236,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.buckets[bucket_index].append(image_info.image_key)
|
||||
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices: list(BucketBatchIndex) = []
|
||||
@@ -479,6 +494,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
@@ -539,6 +556,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||
self.register_image(info)
|
||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
@@ -555,6 +573,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||
reg_infos.append(info)
|
||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
|
||||
print(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
@@ -627,6 +646,8 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
# check existence of all npz files
|
||||
if not self.color_aug:
|
||||
npz_any = False
|
||||
@@ -669,6 +690,8 @@ class FineTuningDataset(BaseDataset):
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
@@ -681,6 +704,9 @@ class FineTuningDataset(BaseDataset):
|
||||
self.bucket_resos.sort()
|
||||
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
||||
|
||||
self.min_bucket_reso = min([min(reso) for reso in resos])
|
||||
self.max_bucket_reso = max([max(reso) for reso in resos])
|
||||
|
||||
def image_key_to_npz_file(self, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + '.npz'
|
||||
@@ -767,9 +793,9 @@ def default(val, d):
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
"""Old model hash used by stable-diffusion-webui"""
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
@@ -779,6 +805,61 @@ def model_hash(filename):
|
||||
return 'NOFILE'
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
"""New model hash used by stable-diffusion-webui"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def precalculate_safetensors_hashes(tensors, metadata):
|
||||
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
||||
save time on indexing the model later."""
|
||||
|
||||
# Because writing user metadata to the file can change the result of
|
||||
# sd_models.model_hash(), only retain the training metadata for purposes of
|
||||
# calculating the hash, as they are meant to be immutable
|
||||
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
||||
|
||||
bytes = safetensors.torch.save(tensors, metadata)
|
||||
b = BytesIO(bytes)
|
||||
|
||||
model_hash = addnet_hash_safetensors(b)
|
||||
legacy_hash = addnet_hash_legacy(b)
|
||||
return model_hash, legacy_hash
|
||||
|
||||
|
||||
def addnet_hash_legacy(b):
|
||||
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||
m = hashlib.sha256()
|
||||
|
||||
b.seek(0x100000)
|
||||
m.update(b.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
|
||||
|
||||
def addnet_hash_safetensors(b):
|
||||
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
for chunk in iter(lambda: b.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
@@ -1046,7 +1127,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
||||
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
@@ -1065,8 +1150,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||
parser.add_argument("--max_train_epochs", type=int, default=None,
|
||||
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
||||
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||
@@ -1316,7 +1403,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
||||
|
||||
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
remove_epoch_no = None
|
||||
if saving:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
save_func()
|
||||
@@ -1324,7 +1410,7 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
|
||||
if args.save_last_n_epochs is not None:
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
||||
remove_old_func(remove_epoch_no)
|
||||
return saving, remove_epoch_no
|
||||
return saving
|
||||
|
||||
|
||||
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
||||
@@ -1364,15 +1450,18 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
save_func = save_du
|
||||
remove_old_func = remove_du
|
||||
|
||||
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
|
||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
||||
|
||||
|
||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
|
||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||
if remove_epoch_no is not None:
|
||||
|
||||
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
||||
if last_n_epochs is not None:
|
||||
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
||||
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
print(f"removing old state: {state_dir_old}")
|
||||
|
||||
Reference in New Issue
Block a user