Skip to content

Commit 5db44d7

Browse files
marcenacp authored and The TensorFlow Datasets Authors committed
Read the length of the datasource from the FileInstructions to limit I/O.
PiperOrigin-RevId: 737687954
1 parent 27547b2 commit 5db44d7

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

Diff for: tensorflow_datasets/core/data_sources/array_record.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ class ArrayRecordDataSource(base.BaseDataSource):
5656
length: int = dataclasses.field(init=False)
5757

5858
def __post_init__(self):
59-
file_instructions = base.file_instructions(self.dataset_info, self.split)
59+
file_instructions = self.split_info.file_instructions
6060
self.data_source = array_record_data_source.ArrayRecordDataSource(
6161
file_instructions
6262
)

Diff for: tensorflow_datasets/core/data_sources/base.py

+11-11
Original file line number | Diff line number | Diff line change
@@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T:
4545
"""Returns the value for the given `keys`."""
4646

4747

48-
def file_instructions(
49-
dataset_info: dataset_info_lib.DatasetInfo,
50-
split: splits_lib.Split | None = None,
51-
) -> list[shard_utils.FileInstruction]:
52-
"""Retrieves the file instructions from the DatasetInfo."""
53-
split_infos = dataset_info.splits.values()
54-
split_dict = splits_lib.SplitDict(split_infos=split_infos)
55-
return split_dict[split].file_instructions
56-
57-
5848
@dataclasses.dataclass
5949
class BaseDataSource(MappingView, Sequence):
6050
"""Base DataSource to override all dunder methods with the deserialization.
@@ -94,6 +84,16 @@ def _deserialize(self, record: Any) -> Any:
9484
return features.deserialize_example_np(record, decoders=self.decoders) # pylint: disable=attribute-error
9585
raise ValueError('No features set, cannot decode example!')
9686

87+
@property
88+
def split_info(self) -> splits_lib.SplitInfo | splits_lib.SubSplitInfo:
89+
"""Returns the SplitInfo for the split."""
90+
splits = self.dataset_info.splits
91+
if self.split not in splits:
92+
raise ValueError(
93+
f'Split {self.split} not found in dataset {self.dataset_info.name}!'
94+
)
95+
return splits[self.split]
96+
9797
def __getitem__(self, key: SupportsIndex) -> Any:
9898
record = self.data_source[key.__index__()]
9999
return self._deserialize(record)
@@ -133,7 +133,7 @@ def __repr__(self) -> str:
133133
)
134134

135135
def __len__(self) -> int:
136-
return self.data_source.__len__()
136+
return sum(fi.examples_in_shard for fi in self.split_info.file_instructions)
137137

138138
def __iter__(self):
139139
for i in range(self.__len__()):

Diff for: tensorflow_datasets/core/data_sources/parquet.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ class ParquetDataSource(base.BaseDataSource):
5757
"""ParquetDataSource to read from a ParquetDataset."""
5858

5959
def __post_init__(self):
60-
file_instructions = base.file_instructions(self.dataset_info, self.split)
60+
file_instructions = self.split_info.file_instructions
6161
filenames = [
6262
file_instruction.filename for file_instruction in file_instructions
6363
]

0 commit comments

Comments (0)