mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Official weights to LoRA
This commit is contained in:
@@ -826,14 +826,14 @@ class PipelineLike():
|
||||
if isinstance(mask_image[0], PIL.Image.Image):
|
||||
mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint
|
||||
|
||||
# encode the init image into latents and scale the latents
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
if len(init_latents) == 1:
|
||||
init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
||||
init_latents_orig = init_latents
|
||||
# # encode the init image into latents and scale the latents
|
||||
# init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
# init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||
# init_latents = init_latent_dist.sample(generator=generator)
|
||||
# init_latents = 0.18215 * init_latents
|
||||
# if len(init_latents) == 1:
|
||||
# init_latents = init_latents.repeat((batch_size, 1, 1, 1))
|
||||
# init_latents_orig = init_latents
|
||||
|
||||
# # preprocess mask
|
||||
# if mask_image is not None:
|
||||
@@ -846,7 +846,8 @@ class PipelineLike():
|
||||
# raise ValueError("The mask and init_image should be the same size!")
|
||||
|
||||
# init imageをhintとして使う (use the init image as the hint)
|
||||
hint_latents = init_latents
|
||||
hint = init_image
|
||||
# hint_latents = init_latents
|
||||
# org_dtype = init_image.dtype
|
||||
# hint = torch.nn.functional.interpolate(init_image.to(torch.float32), scale_factor=(1/8, 1/8), mode="bilinear")
|
||||
# hint = hint[:, 0].unsqueeze(1) # RGB -> BW
|
||||
@@ -876,7 +877,7 @@ class PipelineLike():
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
hint_latents = torch.cat([hint_latents, hint_latents])
|
||||
# hint_latents = torch.cat([hint_latents, hint_latents])
|
||||
|
||||
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
@@ -885,11 +886,9 @@ class PipelineLike():
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
self.lora_network.set_as_control_path(True)
|
||||
# self.unet(latent_model_input * hint, t, encoder_hidden_states=text_embeddings).sample
|
||||
self.unet(hint_latents, t, encoder_hidden_states=text_embeddings)
|
||||
self.lora_network.set_as_control_path(False)
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
noise_pred = self.lora_network.call_unet(self.unet, hint, latent_model_input, t, encoder_hidden_states=text_embeddings)[0] # .sample
|
||||
|
||||
# noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -1812,7 +1811,8 @@ def preprocess_image(image):
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
# return 2.0 * image - 1.0
|
||||
return image # ControlNet
|
||||
|
||||
|
||||
def preprocess_mask(mask):
|
||||
@@ -2016,6 +2016,7 @@ def main(args):
|
||||
if args.network_module:
|
||||
networks = []
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
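# control_net_lora固定なのでimportする必要はないがとりあえず (the module is fixed to control_net_lora, so importing it is not strictly necessary, but do it anyway for now)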
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
@@ -2040,14 +2041,19 @@ def main(args):
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
from safetensors.torch import load_file
|
||||
sd = load_file(network_weight)
|
||||
|
||||
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
|
||||
network = imported_module.ControlLoRANetwork(unet, sd, network_mul)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
if network is None:
|
||||
return
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
network.apply_to() # text_encoder, unet)
|
||||
info = network.load_state_dict(sd)
|
||||
print(f"loading network: {info}")
|
||||
|
||||
if args.opt_channels_last:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
|
||||
Reference in New Issue
Block a user