update loraplus on dylora/lora_fa

This commit is contained in:
Kohya S
2024-05-06 11:09:32 +09:00
parent 52e64c69cf
commit 7fe81502d0
3 changed files with 71 additions and 34 deletions


@@ -499,7 +499,8 @@ def create_network(
     loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
     loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
     loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
 
     if block_lr_weight is not None:
         network.set_block_lr_weight(block_lr_weight)
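
With this guard, set_loraplus_lr_ratio is only called when at least one ratio was actually supplied, so networks created without LoRA+ keep their default (None) state. For context, here is a minimal sketch of how a LoRA+ ratio is typically turned into optimizer parameter groups; the function name and the "lora_up" substring match common LoRA naming but are illustrative, not sd-scripts' actual API:

    from typing import Optional

    import torch

    def make_loraplus_param_groups(network: torch.nn.Module, base_lr: float, loraplus_lr_ratio: Optional[float]):
        # LoRA+ (Hayou et al., 2024) trains the lora_up (B) matrices with a
        # larger learning rate than the lora_down (A) matrices.
        plus_params, base_params = [], []
        for name, param in network.named_parameters():
            if loraplus_lr_ratio is not None and "lora_up" in name:
                plus_params.append(param)
            else:
                base_params.append(param)
        groups = [{"params": base_params, "lr": base_lr}]
        if plus_params:
            # the "plus" group runs at base_lr scaled by the LoRA+ ratio
            groups.append({"params": plus_params, "lr": base_lr * loraplus_lr_ratio})
        return groups

    # usage: optimizer = torch.optim.AdamW(make_loraplus_param_groups(net, 1e-4, 16.0))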
@@ -855,6 +856,10 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
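
Initializing the three ratios to None in __init__ ensures every LoRANetwork has the attributes even when set_loraplus_lr_ratio is never called (per the guard above), avoiding an AttributeError later. A hedged sketch of how downstream code can then resolve them; resolve_loraplus_ratios is an illustrative helper, not sd-scripts' actual code:

    # Illustrative helper: with the attributes always present, per-component
    # ratios can fall back to the shared ratio without hasattr() checks, and
    # None cleanly signals "LoRA+ disabled" for that component.
    def resolve_loraplus_ratios(network):
        shared = network.loraplus_lr_ratio
        unet = network.loraplus_unet_lr_ratio if network.loraplus_unet_lr_ratio is not None else shared
        text_encoder = network.loraplus_text_encoder_lr_ratio if network.loraplus_text_encoder_lr_ratio is not None else shared
        return unet, text_encoder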