Replace print with logger if they are logs (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Yuta Hayashibe
2024-02-04 16:14:34 +07:00
committed by GitHub
parent 7f948db158
commit 5f6bf29e52
62 changed files with 1195 additions and 961 deletions

View File

@@ -30,7 +30,10 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
@@ -332,7 +335,7 @@ class ResnetBlock2D(nn.Module):
def forward(self, x, emb):
if self.training and self.gradient_checkpointing:
# print("ResnetBlock2D: gradient_checkpointing")
# logger.info("ResnetBlock2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -366,7 +369,7 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states):
if self.training and self.gradient_checkpointing:
# print("Downsample2D: gradient_checkpointing")
# logger.info("Downsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -653,7 +656,7 @@ class BasicTransformerBlock(nn.Module):
def forward(self, hidden_states, context=None, timestep=None):
if self.training and self.gradient_checkpointing:
# print("BasicTransformerBlock: checkpointing")
# logger.info("BasicTransformerBlock: checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -796,7 +799,7 @@ class Upsample2D(nn.Module):
def forward(self, hidden_states, output_size=None):
if self.training and self.gradient_checkpointing:
# print("Upsample2D: gradient_checkpointing")
# logger.info("Upsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -1046,7 +1049,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block:
if hasattr(module, "set_use_memory_efficient_attention"):
# print(module.__class__.__name__)
# logger.info(module.__class__.__name__)
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1061,7 +1064,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block.modules():
if hasattr(module, "gradient_checkpointing"):
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
# logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1083,7 +1086,7 @@ class SdxlUNet2DConditionModel(nn.Module):
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
@@ -1135,14 +1138,14 @@ class InferSdxlUNet2DConditionModel:
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
@@ -1229,7 +1232,7 @@ class InferSdxlUNet2DConditionModel:
if __name__ == "__main__":
import time
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModel()
unet.to("cuda")
@@ -1238,7 +1241,7 @@ if __name__ == "__main__":
unet.train()
# 使用メモリ量確認用の疑似学習ループ
print("preparing optimizer")
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
@@ -1253,12 +1256,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
@@ -1278,4 +1281,4 @@ if __name__ == "__main__":
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")