doc: add comment for clarification

Author: Kohya S
Date: 2026-01-18 14:49:47 +09:00
parent 555d745ba1
commit d1ea0073ed


@@ -1208,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
             if len(batch) >= caching_strategy.batch_size:
                 submit_batch(batch, current_condition)
                 batch = []
-                # current_condition = None
+                # current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call

         if len(batch) > 0:
             submit_batch(batch, current_condition)
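
Note on the added comment: the caching loop groups work by a per-batch condition, and the comment explains why the condition is not reset after a full batch is flushed. The following is a minimal sketch of that pattern, not the repository's actual code; the `get_condition` helper and the assumption that `submit_batch` calls `clean_memory_on_device` whenever the condition changes are illustrative guesses.

def cache_all(items, batch_size, get_condition, submit_batch):
    # Minimal sketch, assuming submit_batch() calls clean_memory_on_device()
    # whenever the condition it receives differs from the previous one.
    batch = []
    current_condition = None
    for item in items:
        condition = get_condition(item)  # hypothetical helper
        if current_condition is not None and condition != current_condition:
            # condition changed: flush what we have under the old condition
            submit_batch(batch, current_condition)
            batch = []
        current_condition = condition
        batch.append(item)
        if len(batch) >= batch_size:
            submit_batch(batch, current_condition)
            batch = []
            # current_condition = None  # resetting here would make the next item
            # look like a condition change and trigger an extra
            # clean_memory_on_device call, so the value is kept
    if len(batch) > 0:
        submit_batch(batch, current_condition)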
@@ -1771,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset):
                 tensors = [converter(x) for x in tensors]

             if tensors[0].ndim == 1:
                 # input_ids or mask
-                result.append(
-                    torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
-                )
+                result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
             else:
                 # text encoder outputs
-                result.append(
-                    torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
-                )
+                result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))

         return result

     # set example
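
On the reflowed padding lines: for a 1-D tensor, torch.nn.functional.pad(x, (0, n)) appends n zeros at the end, while for a 2-D tensor of shape [seq_len, hidden], pad(x, (0, 0, 0, n)) leaves the last dimension untouched and pads the sequence dimension up to max_len. A small self-contained sketch with made-up shapes (the lengths and hidden size below are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

max_len = 7  # hypothetical target length

# 1-D case (input_ids or mask): (0, pad_right) pads the single dimension.
ids = [torch.ones(5, dtype=torch.long), torch.ones(7, dtype=torch.long)]
stacked_ids = torch.stack([F.pad(x, (0, max_len - x.shape[0])) for x in ids])
print(stacked_ids.shape)  # torch.Size([2, 7])

# 2-D case (text encoder outputs of shape [seq_len, hidden]):
# (0, 0, 0, pad_bottom) keeps the hidden dimension and pads dim 0 to max_len.
embs = [torch.randn(5, 768), torch.randn(7, 768)]
stacked_embs = torch.stack([F.pad(x, (0, 0, 0, max_len - x.shape[0])) for x in embs])
print(stacked_embs.shape)  # torch.Size([2, 7, 768])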