support sai model spec

This commit is contained in:
Kohya S
2023-08-06 21:50:05 +09:00
parent cd54af019a
commit c142dadb46
15 changed files with 746 additions and 64 deletions

View File

@@ -156,10 +156,11 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する