mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
差分学習機能追加
This commit is contained in:
@@ -148,6 +148,20 @@ def train(args):
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_modules is not None:
|
||||
for module_path in args.base_modules:
|
||||
print("merging module: %s"%module_path)
|
||||
module, weights_sd = network_module.create_network_from_weights(args.base_modules_weight, module_path, vae, text_encoder, unet, for_inference=True)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, "cpu")
|
||||
print("all modules merged: %s"%", ".join(args.base_modules))
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -162,13 +176,6 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
for net_arg in args.network_args:
|
||||
@@ -770,6 +777,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules",
|
||||
type=str, default=None, nargs="*",
|
||||
help="base modules for differential learning / 差分学習用のベースモデル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules_weight",
|
||||
type=float, default=1,
|
||||
help="weight of base modules for differential learning / 差分学習用のベースモデルの比重",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user