adding torchao and fixes

This commit is contained in:
Darren Laurie
2025-03-17 01:37:23 +08:00
parent 4a3ced5fb9
commit d35c51a59e
3 changed files with 34 additions and 4 deletions

View File

@@ -708,21 +708,26 @@ class NativeTrainer:
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
# TODO: SDXL Model Specific
# TODO: Is casting to torch.tensor slowing down the performance so much? (20% slower)
training_models = []
params_to_optimize = []
using_torchao = args.optimizer_type.endswith("4bit") or args.optimizer_type.endswith("Fp8")
if train_unet:
training_models.append(unet)
if block_lrs is None:
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
lr_unet = args.learning_rate
params_to_optimize.append({"params": list(unet.parameters()), "lr": torch.tensor(lr_unet) if using_torchao else lr_unet})
else:
params_to_optimize.extend(self.get_block_params_to_optimize(unet, block_lrs))
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
lr_te1 = args.learning_rate_te1 or args.learning_rate
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": torch.tensor(lr_te1) if using_torchao else lr_te1})
if train_text_encoder2:
training_models.append(text_encoder2)
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
lr_te2 = args.learning_rate_te2 or args.learning_rate
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": torch.tensor(lr_te2) if using_torchao else lr_te2})
# calculate number of trainable parameters
n_params = 0