mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
doc: add comment for clarification
This commit is contained in:
@@ -1208,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if len(batch) >= caching_strategy.batch_size:
|
if len(batch) >= caching_strategy.batch_size:
|
||||||
submit_batch(batch, current_condition)
|
submit_batch(batch, current_condition)
|
||||||
batch = []
|
batch = []
|
||||||
# current_condition = None
|
# current_condition = None # keep current_condition to avoid next `clean_memory_on_device` call
|
||||||
|
|
||||||
if len(batch) > 0:
|
if len(batch) > 0:
|
||||||
submit_batch(batch, current_condition)
|
submit_batch(batch, current_condition)
|
||||||
@@ -1771,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
tensors = [converter(x) for x in tensors]
|
tensors = [converter(x) for x in tensors]
|
||||||
if tensors[0].ndim == 1:
|
if tensors[0].ndim == 1:
|
||||||
# input_ids or mask
|
# input_ids or mask
|
||||||
result.append(
|
result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
|
||||||
torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# text encoder outputs
|
# text encoder outputs
|
||||||
result.append(
|
result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
|
||||||
torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# set example
|
# set example
|
||||||
|
|||||||
Reference in New Issue
Block a user