Compare commits

..

6 Commits

Author SHA1 Message Date
Kohya S
f33e155c5b simplyfy code by updating accelerate to 0.30.0 2024-05-12 15:48:40 +09:00
Kohya S
c1ef6dcabc fix to work with wrapped optimizer by accelerate 2024-05-06 13:17:14 +09:00
Kohya S
5fe9ded188 simplify codes for schedule free optimizer 2024-05-04 21:03:47 +09:00
青龍聖者@bdsqlsz
c68712635c Support new optimizer Schedule free (#1250)
* init

* use no schedule

* fix typo

* update for eval()

* fix typo
2024-05-04 18:56:27 +09:00
Kohya S
0540c33aca pop weights if available #1247 2024-04-21 17:45:29 +09:00
Kohya S
52652cba1a disable main process check for deepspeed #1247 2024-04-21 17:41:32 +09:00
19 changed files with 337 additions and 627 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.28.1
uses: crate-ci/typos@v1.19.0

View File

@@ -3,10 +3,6 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ
[README in English](./README.md) ←更新情報はこちらにあります
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
以下のスクリプトがあります。

View File

@@ -5,11 +5,6 @@ This repository contains training, generation and utility scripts for Stable Dif
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
This repository contains the scripts for:
@@ -142,41 +137,6 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### Oct 27, 2024 / 2024-10-27:
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
- This will be included in the next release.
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します32GBあれば十分です
- これは次回リリースに含まれます。
### Oct 26, 2024 / 2024-10-26:
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
- It will be included in the next release.
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
- 以上は次回リリースに含まれます。
### Sep 13, 2024 / 2024-09-13:
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
- `sdxl_merge_lora.py` also supports LBW.
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
- These will be included in the next release.
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
- `sdxl_merge_lora.py` でも LBW がサポートされました。
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
- 以上は次回リリースに含まれます。
### Jun 23, 2024 / 2024-06-23:
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
### Apr 7, 2024 / 2024-04-07: v0.8.7
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.

View File

@@ -250,23 +250,32 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -324,6 +333,7 @@ def train(args):
m.train()
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
@@ -354,7 +364,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -368,7 +380,9 @@ def train(args):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -380,7 +394,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -390,9 +406,11 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -471,7 +489,7 @@ def train(args):
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す

View File

@@ -1489,7 +1489,7 @@ class DreamBoothDataset(BaseDataset):
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
logger.warning(f"not directory: {subset.image_dir}")
return [], [], []
return [], []
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
@@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
)
parser.add_argument(
"--gradient_accumulation_steps",
@@ -4088,6 +4088,21 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
if optimizer is None:
# 任意のoptimizerを使う
optimizer_type = args.optimizer_type # lowerでないやつ微妙
@@ -4116,6 +4131,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""
# supports schedule free optimizer
if args.optimizer_type.lower().endswith("schedulefree"):
# return dummy scheduler: it has 'step' method but does nothing
logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
lr_scheduler.step = lambda: None
return lr_scheduler
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
@@ -4250,7 +4273,7 @@ def load_tokenizer(args: argparse.Namespace):
return tokenizer
def prepare_accelerator(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
"""
this function also prepares deepspeed plugin
"""

View File

@@ -39,7 +39,12 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, metadata):
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
@@ -344,18 +349,12 @@ def resize(args):
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
save_to_file(args.save_to, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -1,25 +1,18 @@
import itertools
import math
import argparse
import os
import time
import concurrent.futures
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
@@ -35,58 +28,36 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, model, metadata):
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
def detect_method_from_training_model(models, dtype):
for model in models:
# TODO It is better to use key names to detect the method
lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()):
if "lora_up" in key or "lora_down" in key:
return "LoRA"
elif "oft_blocks" in key:
return "OFT"
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder1.to(merge_dtype)
text_encoder2.to(merge_dtype)
unet.to(merge_dtype)
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
logger.info(f"method:{method}")
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
if method == "LoRA":
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
elif method == "OFT":
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules():
@@ -97,172 +68,65 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA":
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
elif method == "OFT":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for key in tqdm(lora_sd.keys()):
if "oft_blocks" in key:
oft_blocks = lora_sd[key]
dim = oft_blocks.shape[0]
break
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
oft_blocks = lora_sd[key]
alpha = oft_blocks.item()
break
def merge_to(key):
if "alpha" in key:
return
# find original module for this OFT
module_name = ".".join(key.split(".")[:-1])
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for OFT weight: {key}")
return
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
oft_blocks = lora_sd[key]
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
if isinstance(module, torch.nn.Linear):
out_dim = module.out_features
elif isinstance(module, torch.nn.Conv2d):
out_dim = module.out_channels
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
num_blocks = dim
block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim
multiplier = 1
if lbw:
index = get_lbw_block_index(key, False)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
multiplier *= lbw_weights[index]
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
# get org weight
org_sd = module.state_dict()
org_weight = org_sd["weight"].to(device)
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
if method == "OFT":
raise ValueError(
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
)
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
merged_sd = {}
v2 = None
base_model = None
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
@@ -300,7 +164,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
@@ -314,14 +178,8 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -343,7 +201,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -371,15 +229,7 @@ def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=F
def merge(args):
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == "float":
@@ -407,7 +257,7 @@ def merge(args):
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
if args.no_metadata:
sai_metadata = None
@@ -423,13 +273,7 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
logger.info(f"calculating hashes and creating metadata...")
@@ -446,7 +290,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -472,19 +316,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument(
"--no_metadata",
action="store_true",
@@ -500,7 +337,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,8 +1,5 @@
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -11,195 +8,12 @@ from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]
LAYER12 = {
"BASE": True,
"IN00": False,
"IN01": False,
"IN02": False,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": False,
"OUT07": False,
"OUT08": False,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER17 = {
"BASE": True,
"IN00": False,
"IN01": True,
"IN02": True,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": False,
"OUT01": False,
"OUT02": False,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
LAYER20 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER26 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": True,
"IN10": True,
"IN11": True,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
if "text_model_encoder_" in lora_name: # LoRA for text encoder
return 0
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
block_idx = -1 # invalid lora name
if not is_sdxl:
NUM_OF_BLOCKS = 12 # up/down blocks
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
up_down = g[0]
i = int(g[1])
j = int(g[3])
if up_down == "down":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j + 1
elif g[2] == "downsamplers":
idx = 3 * (i + 1)
else:
return block_idx # invalid lora name
elif up_down == "up":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers":
idx = 3 * i + 2
else:
return block_idx # invalid lora name
if g[0] == "down":
block_idx = 1 + idx # 1-based index, down block index
elif g[0] == "up":
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
elif "mid_block_" in lora_name:
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
else:
# SDXL: some numbers are skipped
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
block_idx = 1
elif name.startswith("input_blocks_"): # 1-8 to 2-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 13
block_idx = 13
elif name.startswith("output_blocks_"): # 0-8 to 14-22
block_idx = 14 + int(name.split("_")[2])
elif name.startswith("out_"): # 23, No LoRA in sd-scripts
block_idx = 23
return block_idx
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -216,53 +30,24 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, metadata):
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
torch.save(state_dict, file_name)
def format_lbws(lbws):
try:
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
lbws = [json.loads(lbw) for lbw in lbws]
except Exception:
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
assert all(
len(lbw) in ACCEPTABLE for lbw in lbws
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
assert all(
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
layer_num = len(lbws[0])
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
FLAGS = {
"12": LAYER12.values(),
"17": LAYER17.values(),
"20": LAYER20.values(),
"26": LAYER26.values(),
}[str(layer_num)]
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
return lbws, is_sdxl, LBW_TARGET_IDX
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
v2 = None
base_model = None
if lbws:
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
else:
is_sdxl = False
LBW_TARGET_IDX = []
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
@@ -272,12 +57,6 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
# merge
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
@@ -301,10 +80,10 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
if device:
weight = weight.to(device)
# merge to weight
if device:
@@ -314,12 +93,6 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
# W <- W + U * D
scale = alpha / network_dim
if lbw:
index = get_lbw_block_index(key, is_sdxl)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
@@ -336,16 +109,13 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
merged_sd[lora_module_name] = weight.to("cpu")
merged_sd[lora_module_name] = weight
# extract from merged weights
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
if device:
mat = mat.to(device)
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -384,7 +154,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
# build minimum metadata
dims = f"{new_rank}"
@@ -399,15 +169,7 @@ def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, mer
def merge(args):
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == "float":
@@ -425,15 +187,9 @@ def merge(args):
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
@@ -455,7 +211,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
save_to_file(args.save_to, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -475,19 +231,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
@@ -495,9 +244,7 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--no_metadata",
action="store_true",

View File

@@ -1,14 +1,15 @@
accelerate==0.25.0
accelerate==0.30.0
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.8.1.78
opencv-python==4.7.0.68
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.43.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.5
tensorboard
safetensors==0.4.2
# gradio==3.16.2

View File

@@ -407,6 +407,7 @@ def train(args):
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
@@ -415,9 +416,9 @@ def train(args):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
@@ -428,7 +429,17 @@ def train(args):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
@@ -503,6 +514,7 @@ def train(args):
m.train()
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
@@ -582,7 +594,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -600,7 +614,9 @@ def train(args):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -616,7 +632,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -626,9 +644,11 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -736,7 +756,7 @@ def train(args):
accelerator.end_training()
if args.save_state or args.save_state_on_train_end:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -286,7 +287,18 @@ def train(args):
unet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
@@ -390,6 +402,7 @@ def train(args):
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
@@ -439,7 +452,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -458,7 +473,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -484,6 +501,8 @@ def train(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -254,9 +255,19 @@ def train(args):
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
network: control_net_lllite.ControlNetLLLite
if args.gradient_checkpointing:
@@ -357,6 +368,7 @@ def train(args):
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
@@ -406,7 +418,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -426,7 +440,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -452,6 +468,8 @@ def train(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -16,13 +16,12 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache latents arg
@@ -95,7 +94,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -172,7 +170,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -16,13 +16,12 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache arg
@@ -100,7 +99,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -173,7 +171,6 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -13,6 +13,7 @@ from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -226,7 +227,7 @@ def train(args):
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -276,9 +277,18 @@ def train(args):
controlnet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
controlnet, optimizer, train_dataloader = accelerator.prepare(controlnet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -393,6 +403,7 @@ def train(args):
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
@@ -420,7 +431,9 @@ def train(args):
)
# Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -452,7 +465,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -472,6 +487,8 @@ def train(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -224,25 +224,34 @@ def train(args):
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
training_models = [unet]
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -307,6 +316,7 @@ def train(args):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
@@ -346,7 +356,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -358,7 +370,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -387,6 +401,8 @@ def train(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -412,6 +412,7 @@ class NetworkTrainer:
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
@@ -420,9 +421,9 @@ class NetworkTrainer:
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_model = ds_model
else:
if train_unet:
@@ -438,14 +439,23 @@ class NetworkTrainer:
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_model = network
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc in text_encoders:
t_enc.train()
@@ -474,13 +484,15 @@ class NetworkTrainer:
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
if accelerator.is_main_process:
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
weights.pop(i)
if len(weights) > i:
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")
def load_model_hook(models, input_dir):
@@ -802,6 +814,7 @@ class NetworkTrainer:
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
@@ -918,6 +931,8 @@ class NetworkTrainer:
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -415,20 +415,28 @@ class TextualInversionTrainer:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if len(text_encoders) == 1:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
text_encoder_or_list, optimizer, train_dataloader = accelerator.preparet(
text_encoder_or_list, optimizer, train_dataloader
)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader
)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
else:
raise NotImplementedError()
if not use_schedule_free_optimizer:
optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
index_no_updates_list = []
orig_embeds_params_list = []
@@ -557,6 +565,7 @@ class TextualInversionTrainer:
loss_total = 0
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
@@ -588,7 +597,9 @@ class TextualInversionTrainer:
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -627,6 +638,8 @@ class TextualInversionTrainer:
index_no_updates
]
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)

View File

@@ -335,9 +335,18 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
@@ -438,6 +447,7 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
@@ -461,7 +471,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -473,7 +485,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -505,6 +519,8 @@ def train(args):
index_no_updates
]
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)