mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove workaround for transfomers bs>1 close #869
This commit is contained in:
@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
r"""
|
||||||
|
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
|
||||||
|
|
||||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||||
@@ -65,6 +68,7 @@ def main(args):
|
|||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||||
|
"""
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
print(f"load images from {args.train_data_dir}")
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
@@ -81,7 +85,7 @@ def main(args):
|
|||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
imgs = [im for _, im in path_imgs]
|
imgs = [im for _, im in path_imgs]
|
||||||
|
|
||||||
curr_batch_size[0] = len(path_imgs)
|
# curr_batch_size[0] = len(path_imgs)
|
||||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user