Combine params into a single group when stratified learning rates are not used

u-haru
2023-04-02 12:57:55 +09:00
parent 058e442072
commit 19340d82e6


@@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module):
             all_params.append(param_data)
 
         if self.unet_loras:
-            for lora in self.unet_loras:
-                param_data = {'params': lora.parameters()}
+            if self.stratified_lr:
+                for lora in self.unet_loras:
+                    param_data = {'params': lora.parameters()}
+                    if unet_lr is not None:
+                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                    elif default_lr is not None:
+                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
+                    if ('lr' in param_data) and (param_data['lr']==0):
+                        continue
+                    all_params.append(param_data)
+            else:
+                params = []
+                for lora in self.unet_loras:
+                    params.extend(lora.parameters())
+                param_data = {'params': params}
                 if unet_lr is not None:
                     param_data['lr'] = unet_lr
                 elif default_lr is not None:
                     param_data['lr'] = default_lr
-                if self.stratified_lr and ('lr' in param_data):
-                    param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
-                    if (param_data['lr']==0):
-                        continue
                 all_params.append(param_data)
 
         return all_params
 
     def enable_gradient_checkpointing(self):
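
As background, the list built by prepare_optimizer_params follows PyTorch's standard parameter-group format: each dict carries a 'params' iterable and an optional per-group 'lr', and the list is passed directly to an optimizer constructor. The minimal sketch below (illustrative only; the nn.Linear stand-ins and the lr_weight_for helper are not from this repository) shows the two grouping strategies side by side and that either result is accepted by torch.optim.AdamW.

import torch
import torch.nn as nn

# Stand-ins for the U-Net LoRA modules iterated as self.unet_loras.
fake_loras = [nn.Linear(4, 4) for _ in range(3)]
unet_lr = 1e-4
stratified_lr = False  # toggle to compare the two grouping strategies

def lr_weight_for(index):
    # Hypothetical per-module weight, playing the role of get_stratified_lr_weight().
    return [0.0, 0.5, 1.0][index]

all_params = []
if stratified_lr:
    # One parameter group per module so each gets its own scaled lr;
    # groups whose lr works out to 0 are skipped entirely.
    for i, lora in enumerate(fake_loras):
        lr = unet_lr * lr_weight_for(i)
        if lr == 0:
            continue
        all_params.append({'params': lora.parameters(), 'lr': lr})
else:
    # What this commit introduces: a single group holding every module's parameters.
    params = []
    for lora in fake_loras:
        params.extend(lora.parameters())
    all_params.append({'params': params, 'lr': unet_lr})

# Either form is a valid torch.optim parameter-group list.
optimizer = torch.optim.AdamW(all_params)
print(len(optimizer.param_groups), 'parameter group(s)')

With stratified_lr off there is exactly one U-Net parameter group instead of one per module, which keeps the optimizer state compact and the learning-rate logging simpler.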