Mirror of https://github.com/kohya-ss/sd-scripts.git
Group params into a single entry when per-layer (stratified) learning rates are not used
@@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module):
             all_params.append(param_data)
 
         if self.unet_loras:
-            for lora in self.unet_loras:
-                param_data = {'params': lora.parameters()}
-                if unet_lr is not None:
-                    param_data['lr'] = unet_lr
-                elif default_lr is not None:
-                    param_data['lr'] = default_lr
-                if self.stratified_lr and ('lr' in param_data):
-                    param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
-                if (param_data['lr']==0):
-                    continue
-                all_params.append(param_data)
+            if self.stratified_lr:
+                for lora in self.unet_loras:
+                    param_data = {'params': lora.parameters()}
+                    if unet_lr is not None:
+                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                    elif default_lr is not None:
+                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
+                    if ('lr' in param_data) and (param_data['lr']==0):
+                        continue
+                    all_params.append(param_data)
+            else:
+                params = []
+                for lora in self.unet_loras:
+                    params.extend(lora.parameters())
+                param_data = {'params': params}
+                if unet_lr is not None:
+                    param_data['lr'] = unet_lr
+                all_params.append(param_data)
 
         return all_params
 
     def enable_gradient_checkpointing(self):
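For context, a minimal sketch (not part of the commit) of how such a param-group list is typically consumed: PyTorch optimizers accept a list of dicts, each with its own 'params' and an optional per-group 'lr' that overrides the optimizer's default. The module list, learning rates, and per-layer weights below are made-up stand-ins for the network's LoRA modules and the values returned by get_stratified_lr_weight.

import torch

# Hypothetical stand-ins for the U-Net LoRA modules and their per-layer weights.
unet_loras = [torch.nn.Linear(8, 8) for _ in range(3)]
unet_lr = 1e-4
stratified_lr = True
lr_weights = [0.0, 0.5, 1.0]

all_params = []
if stratified_lr:
    # One param group per module, lr scaled by its layer weight; groups with lr == 0 are skipped.
    for lora, weight in zip(unet_loras, lr_weights):
        lr = unet_lr * weight
        if lr == 0:
            continue
        all_params.append({'params': lora.parameters(), 'lr': lr})
else:
    # No per-layer weights: merge every module's parameters into a single group.
    params = []
    for lora in unet_loras:
        params.extend(lora.parameters())
    all_params.append({'params': params, 'lr': unet_lr})

optimizer = torch.optim.AdamW(all_params, lr=unet_lr)  # per-group 'lr' overrides this default
print(len(optimizer.param_groups))  # 2 here: the zero-weight layer was skipped

Collapsing all U-Net params into one group when no per-layer weights are used keeps the optimizer state simpler and the param-group list short, which is the point of this commit.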