format by black

This commit is contained in:
Kohya S
2023-03-29 21:23:27 +09:00
parent b996f5a6d6
commit bf3674c1db

View File

@@ -13,114 +13,114 @@ from library import train_util
class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear/Conv2d module,
    instead of replacing the original module itself.
    """

    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        # if limit_rank:
        #   self.lora_dim = min(lora_dim, in_dim, out_dim)
        #   if self.lora_dim != lora_dim:
        #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
        # else:
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        # use isinstance instead of type(...) == comparison (idiomatic, also accepts subclasses)
        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.region = None
        self.region_mask = None

    def apply_to(self):
        """Hook this LoRA into the original module's forward and drop the module reference."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def set_region(self, region):
        """Set a spatial region mask; the cached interpolated mask is invalidated."""
        self.region = region
        self.region_mask = None

    def forward(self, x):
        """Original forward plus scaled LoRA delta; optionally masked by a region."""
        if self.region is None:
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

        # regional LoRA   FIXME same as additional-network extension
        if x.size()[1] % 77 == 0:
            # sequence length is a multiple of 77 -> this is a text-encoder context, no region
            # print(f"LoRA for context: {self.lora_name}")
            self.region = None
            return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

        # calculate region mask first time
        if self.region_mask is None:
            if len(x.size()) == 4:
                h, w = x.size()[2:4]
            else:
                # recover (h, w) from the flattened sequence length, keeping the region's aspect ratio
                seq_len = x.size()[1]
                ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
                h = int(self.region.size()[0] / ratio + 0.5)
                w = seq_len // h

            r = self.region.to(x.device)
            if r.dtype == torch.bfloat16:
                r = r.to(torch.float)  # interpolate does not support bf16
            r = r.unsqueeze(0).unsqueeze(1)
            # print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
            r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
            r = r.to(x.dtype)

            if len(x.size()) == 3:
                r = torch.reshape(r, (1, x.size()[1], -1))

            self.region_mask = r

        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
# extract dim/alpha for conv2d, and block dim # extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get('conv_dim', None) conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get('conv_alpha', None) conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None: if conv_dim is not None:
conv_dim = int(conv_dim) conv_dim = int(conv_dim)
if conv_alpha is None: if conv_alpha is None:
conv_alpha = 1.0 conv_alpha = 1.0
else: else:
conv_alpha = float(conv_alpha) conv_alpha = float(conv_alpha)
""" """
block_dims = kwargs.get("block_dims") block_dims = kwargs.get("block_dims")
block_alphas = None block_alphas = None
@@ -148,251 +148,276 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}" assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
""" """
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, network = LoRANetwork(
alpha=network_alpha, conv_lora_dim=conv_dim, conv_alpha=conv_alpha) text_encoder,
return network unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
    """Create a LoRANetwork whose per-module dims/alphas are read from a weights file.

    If weights_sd is None, the state dict is loaded from `file` (safetensors or
    torch format); otherwise the given state dict is used directly.
    """
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # print(lora_name, value.size(), dim)

    # support old LoRA without alpha: default alpha to the rank (scale == 1)
    for key in modules_dim.keys():
        if key not in modules_alpha:
            # BUGFIX: original code did `modules_alpha = modules_dim[key]`,
            # replacing the whole dict with an int instead of setting one entry.
            modules_alpha[key] = modules_dim[key]

    network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
    network.weights_sd = weights_sd
    return network
class LoRANetwork(torch.nn.Module):
    """Collection of LoRAModules applied to a text encoder and a U-Net."""

    # is it possible to apply conv_in and conv_out?
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        conv_lora_dim=None,
        conv_alpha=None,
        modules_dim=None,
        modules_alpha=None,
    ) -> None:
        """Build LoRA modules for all target submodules of text_encoder and unet.

        When modules_dim/modules_alpha are given (loaded from a weights file),
        they override the uniform lora_dim/alpha settings per module.
        """
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha

        if modules_dim is not None:
            print("create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")

        self.apply_to_conv2d_3x3 = self.conv_lora_dim is not None
        if self.apply_to_conv2d_3x3:
            if self.conv_alpha is None:
                self.conv_alpha = self.alpha
            print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    # TODO get block index here
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            if modules_dim is not None:
                                if lora_name not in modules_dim:
                                    continue  # no LoRA module in this weights file
                                dim = modules_dim[lora_name]
                                alpha = modules_alpha[lora_name]
                            else:
                                if is_linear or is_conv2d_1x1:
                                    dim = self.lora_dim
                                    alpha = self.alpha
                                elif self.apply_to_conv2d_3x3:
                                    dim = self.conv_lora_dim
                                    alpha = self.conv_alpha
                                else:
                                    continue  # 3x3 conv without conv LoRA enabled

                            lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
                            loras.append(lora)
            return loras

        self.text_encoder_loras = create_modules(
            LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
        )
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
        if modules_dim is not None or self.conv_lora_dim is not None:
            # BUGFIX: `target_modules += ...` mutated the shared class attribute
            # UNET_TARGET_REPLACE_MODULE in place (list aliasing), permanently
            # extending the target list for every later instance. Concatenate
            # into a new list instead.
            target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        self.weights_sd = None

        # assertion: lora names must be unique because they become submodule names
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        """Propagate a new multiplier to every LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load a state dict from a safetensors or torch checkpoint file."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            self.weights_sd = load_file(file)
        else:
            self.weights_sd = torch.load(file, map_location="cpu")

    def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
        """Hook the LoRA modules into their target modules and load weights if present.

        When weights are loaded, the apply flags default to whatever the weights
        contain, and an explicit flag must agree with the weights.
        """
        if self.weights_sd:
            weights_has_text_encoder = weights_has_unet = False
            for key in self.weights_sd.keys():
                if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                    weights_has_text_encoder = True
                elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
                    weights_has_unet = True

            if apply_text_encoder is None:
                apply_text_encoder = weights_has_text_encoder
            else:
                assert (
                    apply_text_encoder == weights_has_text_encoder
                ), f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"

            if apply_unet is None:
                apply_unet = weights_has_unet
            else:
                assert (
                    apply_unet == weights_has_unet
                ), f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
        else:
            assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

        if self.weights_sd:
            # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
            info = self.load_state_dict(self.weights_sd, False)
            print(f"weights are loaded: {info}")

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
        """Return optimizer param groups, with optional per-part learning rates."""

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        self.requires_grad_(True)
        all_params = []

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the LoRA state dict, optionally cast to dtype, as safetensors or torch."""
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    @staticmethod
    def set_regions(networks, image):
        """Assign one RGB channel of `image` (uint8 HWC) as the region of each network."""
        image = image.astype(np.float32) / 255.0
        for i, network in enumerate(networks[:3]):
            # NOTE: consider averaging overwrapping area
            region = image[:, :, i]
            if region.max() == 0:
                continue
            region = torch.tensor(region)
            network.set_region(region)

    def set_region(self, region):
        for lora in self.unet_loras:
            lora.set_region(region)