support older type hint

This commit is contained in:
Kohya S
2023-04-02 16:18:04 +09:00
parent 97e65bf93f
commit c639cb7d5d

View File

@@ -5,7 +5,7 @@
import math import math
import os import os
from typing import List from typing import List, Union
import numpy as np import numpy as np
import torch import torch
import re import re
@@ -247,7 +247,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
# is it possible to apply conv_in and conv_out? # is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
@@ -333,8 +333,8 @@ class LoRANetwork(torch.nn.Module):
self.weights_sd = None self.weights_sd = None
self.up_lr_weight: list[float] = None self.up_lr_weight: List[float] = None
self.down_lr_weight: list[float] = None self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None self.mid_lr_weight: float = None
self.block_lr = False self.block_lr = False
@@ -445,9 +445,9 @@ class LoRANetwork(torch.nn.Module):
# 層別学習率用に層ごとの学習率に対する倍率を定義する # 層別学習率用に層ごとの学習率に対する倍率を定義する
def set_block_lr_weight( def set_block_lr_weight(
self, self,
up_lr_weight: list[float] | str = None, up_lr_weight: Union[List[float], str] = None,
mid_lr_weight: float = None, mid_lr_weight: float = None,
down_lr_weight: list[float] | str = None, down_lr_weight: Union[List[float], str] = None,
zero_threshold: float = 0.0, zero_threshold: float = 0.0,
): ):
# バラメータ未指定時は何もせず、今までと同じ動作とする # バラメータ未指定時は何もせず、今までと同じ動作とする
@@ -456,7 +456,7 @@ class LoRANetwork(torch.nn.Module):
max_len = 12 # フルモデル相当でのup,downの層の数 max_len = 12 # フルモデル相当でのup,downの層の数
def get_list(name) -> list[float]: def get_list(name) -> List[float]:
import math import math
if name == "cosine": if name == "cosine":