mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'main' into textual_inversion
This commit is contained in:
@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||
else:
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
|
||||
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
return new_sd
|
||||
|
||||
@@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||
print("loadint vae:", info)
|
||||
print("loading vae:", info)
|
||||
|
||||
# convert text_model
|
||||
if v2:
|
||||
@@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
|
||||
|
||||
if vae_id.endswith(".bin"):
|
||||
# SD 1.5 VAE on Huggingface
|
||||
vae_sd = torch.load(vae_id, map_location="cpu")
|
||||
converted_vae_checkpoint = vae_sd
|
||||
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
||||
else:
|
||||
# StableDiffusion
|
||||
vae_model = torch.load(vae_id, map_location="cpu")
|
||||
vae_sd = vae_model['state_dict']
|
||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||
else torch.load(vae_id, map_location="cpu"))
|
||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||
|
||||
# vae only or full model
|
||||
full_model = False
|
||||
@@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
@@ -765,6 +765,20 @@ def exists(val):
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
@@ -1051,6 +1065,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||
|
||||
Reference in New Issue
Block a user