mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
層別学習率を使わない場合にparamsをまとめる
This commit is contained in:
@@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
all_params.append(param_data)
|
all_params.append(param_data)
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
|
if self.stratified_lr:
|
||||||
for lora in self.unet_loras:
|
for lora in self.unet_loras:
|
||||||
param_data = {'params': lora.parameters()}
|
param_data = {'params': lora.parameters()}
|
||||||
if unet_lr is not None:
|
if unet_lr is not None:
|
||||||
param_data['lr'] = unet_lr
|
param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
|
||||||
elif default_lr is not None:
|
elif default_lr is not None:
|
||||||
param_data['lr'] = default_lr
|
param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
|
||||||
if self.stratified_lr and ('lr' in param_data):
|
if ('lr' in param_data) and (param_data['lr']==0):
|
||||||
param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
|
|
||||||
if (param_data['lr']==0):
|
|
||||||
continue
|
continue
|
||||||
all_params.append(param_data)
|
all_params.append(param_data)
|
||||||
|
else:
|
||||||
|
params = []
|
||||||
|
for lora in self.unet_loras:
|
||||||
|
params.extend(lora.parameters())
|
||||||
|
param_data = {'params': params}
|
||||||
|
if unet_lr is not None:
|
||||||
|
param_data['lr'] = unet_lr
|
||||||
|
all_params.append(param_data)
|
||||||
|
|
||||||
return all_params
|
return all_params
|
||||||
|
|
||||||
def enable_gradient_checkpointing(self):
|
def enable_gradient_checkpointing(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user