diff --git a/library/train_util.py b/library/train_util.py index 642412dd..a1900609 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1208,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) >= caching_strategy.batch_size: submit_batch(batch, current_condition) batch = [] - # current_condition = None + # current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call if len(batch) > 0: submit_batch(batch, current_condition) @@ -1771,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset): tensors = [converter(x) for x in tensors] if tensors[0].ndim == 1: # input_ids or mask - result.append( - torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]) - ) + result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])) else: # text encoder outputs - result.append( - torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]) - ) + result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])) return result # set example