support older type hint

This commit is contained in:
Kohya S
2023-04-02 16:18:04 +09:00
parent 97e65bf93f
commit c639cb7d5d

View File

@@ -5,7 +5,7 @@
import math
import os
from typing import List
from typing import List, Union
import numpy as np
import torch
import re
@@ -247,7 +247,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class LoRANetwork(torch.nn.Module):
# is it possible to apply conv_in and conv_out?
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
@@ -333,8 +333,8 @@ class LoRANetwork(torch.nn.Module):
self.weights_sd = None
self.up_lr_weight: list[float] = None
self.down_lr_weight: list[float] = None
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None
self.block_lr = False
@@ -445,9 +445,9 @@ class LoRANetwork(torch.nn.Module):
# 層別学習率用に層ごとの学習率に対する倍率を定義する
def set_block_lr_weight(
self,
up_lr_weight: list[float] | str = None,
up_lr_weight: Union[List[float], str] = None,
mid_lr_weight: float = None,
down_lr_weight: list[float] | str = None,
down_lr_weight: Union[List[float], str] = None,
zero_threshold: float = 0.0,
):
# バラメータ未指定時は何もせず、今までと同じ動作とする
@@ -456,7 +456,7 @@ class LoRANetwork(torch.nn.Module):
max_len = 12 # フルモデル相当でのup,downの層の数
def get_list(name) -> list[float]:
def get_list(name) -> List[float]:
import math
if name == "cosine":