mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix default_lr being applied
This commit is contained in:
@@ -407,7 +407,14 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
|
def prepare_optimizer_params(
|
||||||
|
self,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
default_lr,
|
||||||
|
unet_lora_plus_ratio=None,
|
||||||
|
text_encoder_lora_plus_ratio=None
|
||||||
|
):
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
all_params = []
|
all_params = []
|
||||||
|
|
||||||
@@ -442,11 +449,19 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.text_encoder_loras,
|
||||||
|
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||||
|
text_encoder_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.unet_loras,
|
||||||
|
default_lr if unet_lr is None else unet_lr,
|
||||||
|
unet_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
return all_params
|
return all_params
|
||||||
|
|||||||
@@ -1035,7 +1035,14 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return lr_weight
|
return lr_weight
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
|
def prepare_optimizer_params(
|
||||||
|
self,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
default_lr,
|
||||||
|
unet_lora_plus_ratio=None,
|
||||||
|
text_encoder_lora_plus_ratio=None
|
||||||
|
):
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
all_params = []
|
all_params = []
|
||||||
|
|
||||||
@@ -1070,7 +1077,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.text_encoder_loras,
|
||||||
|
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||||
|
text_encoder_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
@@ -1085,14 +1096,19 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
# blockごとにパラメータを設定する
|
# blockごとにパラメータを設定する
|
||||||
for idx, block_loras in block_idx_to_lora.items():
|
for idx, block_loras in block_idx_to_lora.items():
|
||||||
if unet_lr is not None:
|
params = assemble_params(
|
||||||
params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
|
block_loras,
|
||||||
elif default_lr is not None:
|
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
||||||
params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
|
unet_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.unet_loras,
|
||||||
|
default_lr if unet_lr is None else unet_lr,
|
||||||
|
unet_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
return all_params
|
return all_params
|
||||||
|
|||||||
@@ -1033,7 +1033,14 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return lr_weight
|
return lr_weight
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
|
def prepare_optimizer_params(
|
||||||
|
self,
|
||||||
|
text_encoder_lr,
|
||||||
|
unet_lr,
|
||||||
|
default_lr,
|
||||||
|
unet_lora_plus_ratio=None,
|
||||||
|
text_encoder_lora_plus_ratio=None
|
||||||
|
):
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
all_params = []
|
all_params = []
|
||||||
|
|
||||||
@@ -1068,7 +1075,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.text_encoder_loras,
|
||||||
|
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||||
|
text_encoder_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
@@ -1083,14 +1094,19 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
# blockごとにパラメータを設定する
|
# blockごとにパラメータを設定する
|
||||||
for idx, block_loras in block_idx_to_lora.items():
|
for idx, block_loras in block_idx_to_lora.items():
|
||||||
if unet_lr is not None:
|
params = assemble_params(
|
||||||
params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
|
block_loras,
|
||||||
elif default_lr is not None:
|
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
||||||
params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
|
unet_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
|
params = assemble_params(
|
||||||
|
self.unet_loras,
|
||||||
|
default_lr if unet_lr is None else unet_lr,
|
||||||
|
unet_lora_plus_ratio
|
||||||
|
)
|
||||||
all_params.extend(params)
|
all_params.extend(params)
|
||||||
|
|
||||||
return all_params
|
return all_params
|
||||||
|
|||||||
Reference in New Issue
Block a user