Official weights to LoRA

This commit is contained in:
Kohya S
2023-02-13 23:38:38 +09:00
parent bc9fc4ccee
commit cebee02698
5 changed files with 559 additions and 130 deletions

View File

@@ -826,14 +826,14 @@ class PipelineLike():
if isinstance(mask_image[0], PIL.Image.Image):
mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint
# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
if len(init_latents) == 1:
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
init_latents_orig = init_latents
# # encode the init image into latents and scale the latents
# init_image = init_image.to(device=self.device, dtype=latents_dtype)
# init_latent_dist = self.vae.encode(init_image).latent_dist
# init_latents = init_latent_dist.sample(generator=generator)
# init_latents = 0.18215 * init_latents
# if len(init_latents) == 1:
# init_latents = init_latents.repeat((batch_size, 1, 1, 1))
# init_latents_orig = init_latents
# # preprocess mask
# if mask_image is not None:
@@ -846,7 +846,8 @@ class PipelineLike():
# raise ValueError("The mask and init_image should be the same size!")
# use the init image as the hint
hint_latents = init_latents
hint = init_image
# hint_latents = init_latents
# org_dtype = init_image.dtype
# hint = torch.nn.functional.interpolate(init_image.to(torch.float32), scale_factor=(1/8, 1/8), mode="bilinear")
# hint = hint[:, 0].unsqueeze(1) # RGB -> BW
@@ -876,7 +877,7 @@ class PipelineLike():
if accepts_eta:
extra_step_kwargs["eta"] = eta
hint_latents = torch.cat([hint_latents, hint_latents])
# hint_latents = torch.cat([hint_latents, hint_latents])
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
for i, t in enumerate(tqdm(timesteps)):
@@ -885,11 +886,9 @@ class PipelineLike():
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
self.lora_network.set_as_control_path(True)
# self.unet(latent_model_input * hint, t, encoder_hidden_states=text_embeddings).sample
self.unet(hint_latents, t, encoder_hidden_states=text_embeddings)
self.lora_network.set_as_control_path(False)
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.lora_network.call_unet(self.unet, hint, latent_model_input, t, encoder_hidden_states=text_embeddings)[0] # .sample
# noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@@ -1812,7 +1811,8 @@ def preprocess_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
# return 2.0 * image - 1.0
return image # ControlNet
def preprocess_mask(mask):
@@ -2016,6 +2016,7 @@ def main(args):
if args.network_module:
networks = []
for i, network_module in enumerate(args.network_module):
# the module is fixed to control_net_lora, so importing it dynamically is not strictly necessary, but do it anyway for now
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
@@ -2040,14 +2041,19 @@ def main(args):
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
from safetensors.torch import load_file
sd = load_file(network_weight)
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
network = imported_module.ControlLoRANetwork(unet, sd, network_mul)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
network.apply_to(text_encoder, unet)
network.apply_to() # text_encoder, unet)
info = network.load_state_dict(sd)
print(f"loading network: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)