From d1ea0073edf1919669f76c6163db171a38d0fb9f Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 18 Jan 2026 14:49:47 +0900 Subject: [PATCH] doc: add comment for clarification --- library/train_util.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 642412dd..a1900609 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1208,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) >= caching_strategy.batch_size: submit_batch(batch, current_condition) batch = [] - # current_condition = None + # current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call if len(batch) > 0: submit_batch(batch, current_condition) @@ -1771,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset): tensors = [converter(x) for x in tensors] if tensors[0].ndim == 1: # input_ids or mask - result.append( - torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]) - ) + result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])) else: # text encoder outputs - result.append( - torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]) - ) + result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])) return result # set example