mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
差分学習機能追加
This commit is contained in:
@@ -148,6 +148,20 @@ def train(args):
|
|||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
|
|
||||||
|
# prepare network
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(__file__))
|
||||||
|
print("import network module:", args.network_module)
|
||||||
|
network_module = importlib.import_module(args.network_module)
|
||||||
|
|
||||||
|
if args.base_modules is not None:
|
||||||
|
for module_path in args.base_modules:
|
||||||
|
print("merging module: %s"%module_path)
|
||||||
|
module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
|
||||||
|
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
|
||||||
|
print("all modules merged: %s"%", ".join(args.base_modules))
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
@@ -162,13 +176,6 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# prepare network
|
|
||||||
import sys
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(__file__))
|
|
||||||
print("import network module:", args.network_module)
|
|
||||||
network_module = importlib.import_module(args.network_module)
|
|
||||||
|
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args is not None:
|
if args.network_args is not None:
|
||||||
for net_arg in args.network_args:
|
for net_arg in args.network_args:
|
||||||
@@ -770,6 +777,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_modules",
|
||||||
|
type=str, default=None, nargs="*",
|
||||||
|
help="base modules for differential learning / 差分学習用のベースモデル",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_modules_weight",
|
||||||
|
type=float, default=1,
|
||||||
|
help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user