mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
adding torchao and fixes
This commit is contained in:
@@ -708,21 +708,26 @@ class NativeTrainer:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||
|
||||
# TODO: SDXL Model Specific
|
||||
# TODO: Is casting to torch.tensor slowing down the performance so much? (20% slower)
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
using_torchao = args.optimizer_type.endswith("4bit") or args.optimizer_type.endswith("Fp8")
|
||||
if train_unet:
|
||||
training_models.append(unet)
|
||||
if block_lrs is None:
|
||||
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
|
||||
lr_unet = args.learning_rate
|
||||
params_to_optimize.append({"params": list(unet.parameters()), "lr": torch.tensor(lr_unet) if using_torchao else lr_unet})
|
||||
else:
|
||||
params_to_optimize.extend(self.get_block_params_to_optimize(unet, block_lrs))
|
||||
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
lr_te1 = args.learning_rate_te1 or args.learning_rate
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": torch.tensor(lr_te1) if using_torchao else lr_te1})
|
||||
if train_text_encoder2:
|
||||
training_models.append(text_encoder2)
|
||||
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
||||
lr_te2 = args.learning_rate_te2 or args.learning_rate
|
||||
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": torch.tensor(lr_te2) if using_torchao else lr_te2})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
|
||||
Reference in New Issue
Block a user