Replace print with logger if they are logs (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Yuta Hayashibe
2024-02-04 16:14:34 +07:00
committed by GitHub
parent 7f948db158
commit 5f6bf29e52
62 changed files with 1195 additions and 961 deletions

View File

@@ -23,6 +23,10 @@ from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
@@ -140,7 +144,7 @@ if __name__ == "__main__":
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
logger.info("use float32 for vae")
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
@@ -187,7 +191,7 @@ if __name__ == "__main__":
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
@@ -217,7 +221,7 @@ if __name__ == "__main__":
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
@@ -226,7 +230,7 @@ if __name__ == "__main__":
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
@@ -323,4 +327,4 @@ if __name__ == "__main__":
seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")
logger.info("Done!")