@@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T:
45
45
"""Returns the value for the given `keys`."""
46
46
47
47
48
- def file_instructions (
49
- dataset_info : dataset_info_lib .DatasetInfo ,
50
- split : splits_lib .Split | None = None ,
51
- ) -> list [shard_utils .FileInstruction ]:
52
- """Retrieves the file instructions from the DatasetInfo."""
53
- split_infos = dataset_info .splits .values ()
54
- split_dict = splits_lib .SplitDict (split_infos = split_infos )
55
- return split_dict [split ].file_instructions
56
-
57
-
58
48
@dataclasses .dataclass
59
49
class BaseDataSource (MappingView , Sequence ):
60
50
"""Base DataSource to override all dunder methods with the deserialization.
@@ -94,6 +84,16 @@ def _deserialize(self, record: Any) -> Any:
94
84
return features .deserialize_example_np (record , decoders = self .decoders ) # pylint: disable=attribute-error
95
85
raise ValueError ('No features set, cannot decode example!' )
96
86
87
+ @property
88
+ def split_info (self ) -> splits_lib .SplitInfo | splits_lib .SubSplitInfo :
89
+ """Returns the SplitInfo for the split."""
90
+ splits = self .dataset_info .splits
91
+ if self .split not in splits :
92
+ raise ValueError (
93
+ f'Split { self .split } not found in dataset { self .dataset_info .name } !'
94
+ )
95
+ return splits [self .split ]
96
+
97
97
def __getitem__ (self , key : SupportsIndex ) -> Any :
98
98
record = self .data_source [key .__index__ ()]
99
99
return self ._deserialize (record )
@@ -133,7 +133,7 @@ def __repr__(self) -> str:
133
133
)
134
134
135
135
def __len__ (self ) -> int :
136
- return self .data_source . __len__ ( )
136
+ return sum ( fi . examples_in_shard for fi in self .split_info . file_instructions )
137
137
138
138
def __iter__ (self ):
139
139
for i in range (self .__len__ ()):
0 commit comments