Skip to content

Commit 0d6f250

Browse files
authored
Refactor Textsimilarity processor (deepset-ai#711)
* Refactor textsim processor
* Add failsafe mode, add tests
* Simplify dataset type checks
* Clean processor saving and loading
1 parent 96e6786 commit 0d6f250

File tree

4 files changed

+424
-248
lines changed

4 files changed

+424
-248
lines changed

examples/dpr_encoder.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,19 @@ def dense_passage_retrieval():
7272
# i.e., nq-train.json, nq-dev.json or trivia-train.json, trivia-dev.json
7373
label_list = ["hard_negative", "positive"]
7474
metric = "text_similarity_metric"
75-
processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
76-
passage_tokenizer=passage_tokenizer,
77-
max_seq_len_query=64,
78-
max_seq_len_passage=256,
79-
label_list=label_list,
80-
metric=metric,
81-
data_dir="../data/retriever",
82-
train_filename=train_filename,
83-
dev_filename=dev_filename,
84-
test_filename=test_filename,
85-
embed_title=embed_title,
86-
num_hard_negatives=num_hard_negatives,
87-
max_samples=max_samples)
75+
processor = TextSimilarityProcessor(query_tokenizer=query_tokenizer,
76+
passage_tokenizer=passage_tokenizer,
77+
max_seq_len_query=64,
78+
max_seq_len_passage=256,
79+
label_list=label_list,
80+
metric=metric,
81+
data_dir="../data/retriever",
82+
train_filename=train_filename,
83+
dev_filename=dev_filename,
84+
test_filename=test_filename,
85+
embed_title=embed_title,
86+
num_hard_negatives=num_hard_negatives,
87+
max_samples=max_samples)
8888

8989
# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
9090
# NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level

farm/data_handler/dataset.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import logging
44
import torch
55
from torch.utils.data import TensorDataset
6+
from collections.abc import Iterable
7+
from farm.utils import flatten_list
68

79
logger = logging.getLogger(__name__)
810

911

10-
# TODO we need the option to handle different dtypes
1112
def convert_features_to_dataset(features):
1213
"""
1314
Converts a list of feature dictionaries (one for each sample) into a PyTorch Dataset.
@@ -16,8 +17,7 @@ def convert_features_to_dataset(features):
1617
names of the type of feature and the keys are the features themselves.
1718
:Return: a Pytorch dataset and a list of tensor names.
1819
"""
19-
# features can be an empty list in cases where down sampling occurs (e.g. Natural Questions downsamples
20-
# instances of is_impossible
20+
# features can be an empty list in cases where down sampling occurs (e.g. Natural Questions downsamples instances of is_impossible)
2121
if len(features) == 0:
2222
return None, None
2323
tensor_names = list(features[0].keys())
@@ -29,15 +29,15 @@ def convert_features_to_dataset(features):
2929
else:
3030
try:
3131
# Checking whether a non-integer will be silently converted to torch.long
32-
if isinstance(features[0][t_name], numbers.Number):
33-
base = features[0][t_name]
34-
elif isinstance(features[0][t_name], list):
35-
if len(features[0][t_name]) > 0:
36-
base = features[0][t_name][0]
37-
else:
38-
base = 1
32+
check = features[0][t_name]
33+
if isinstance(check, numbers.Number):
34+
base = check
35+
# extract a base variable from nested lists or tuples
36+
elif isinstance(check, Iterable):
37+
base = list(flatten_list(check))[0]
38+
# extract a base variable from numpy arrays
3939
else:
40-
base = features[0][t_name].ravel()[0]
40+
base = check.ravel()[0]
4141
if not np.issubdtype(type(base), np.integer):
4242
logger.warning(f"Problem during conversion to torch tensors:\n"
4343
f"A non-integer value for feature '{t_name}' with a value of: "
@@ -51,4 +51,4 @@ def convert_features_to_dataset(features):
5151
all_tensors.append(cur_tensor)
5252

5353
dataset = TensorDataset(*all_tensors)
54-
return dataset, tensor_names
54+
return dataset, tensor_names

0 commit comments

Comments
 (0)