mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support older type hint
This commit is contained in:
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
@@ -247,7 +247,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
|
|
||||||
|
|
||||||
class LoRANetwork(torch.nn.Module):
|
class LoRANetwork(torch.nn.Module):
|
||||||
# is it possible to apply conv_in and conv_out?
|
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
|
||||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
@@ -333,8 +333,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
self.weights_sd = None
|
self.weights_sd = None
|
||||||
|
|
||||||
self.up_lr_weight: list[float] = None
|
self.up_lr_weight: List[float] = None
|
||||||
self.down_lr_weight: list[float] = None
|
self.down_lr_weight: List[float] = None
|
||||||
self.mid_lr_weight: float = None
|
self.mid_lr_weight: float = None
|
||||||
self.block_lr = False
|
self.block_lr = False
|
||||||
|
|
||||||
@@ -445,9 +445,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
||||||
def set_block_lr_weight(
|
def set_block_lr_weight(
|
||||||
self,
|
self,
|
||||||
up_lr_weight: list[float] | str = None,
|
up_lr_weight: Union[List[float], str] = None,
|
||||||
mid_lr_weight: float = None,
|
mid_lr_weight: float = None,
|
||||||
down_lr_weight: list[float] | str = None,
|
down_lr_weight: Union[List[float], str] = None,
|
||||||
zero_threshold: float = 0.0,
|
zero_threshold: float = 0.0,
|
||||||
):
|
):
|
||||||
# バラメータ未指定時は何もせず、今までと同じ動作とする
|
# バラメータ未指定時は何もせず、今までと同じ動作とする
|
||||||
@@ -456,7 +456,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
max_len = 12 # フルモデル相当でのup,downの層の数
|
max_len = 12 # フルモデル相当でのup,downの層の数
|
||||||
|
|
||||||
def get_list(name) -> list[float]:
|
def get_list(name) -> List[float]:
|
||||||
import math
|
import math
|
||||||
|
|
||||||
if name == "cosine":
|
if name == "cosine":
|
||||||
|
|||||||
Reference in New Issue
Block a user