mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update dataset to return size, refactor ctrlnet ds
This commit is contained in:
@@ -1116,13 +1116,15 @@ if __name__ == "__main__":
|
||||
# 使用メモリ量確認用の疑似学習ループ
|
||||
print("preparing optimizer")
|
||||
|
||||
import bitsandbytes
|
||||
import transformers
|
||||
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
||||
|
||||
# import bitsandbytes
|
||||
# optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
|
||||
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
||||
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
||||
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
||||
|
||||
import transformers
|
||||
|
||||
optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
@@ -1133,7 +1135,7 @@ if __name__ == "__main__":
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda() # 512x512
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
||||
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
||||
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
||||
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
|
||||
|
||||
Reference in New Issue
Block a user