
Commit 9dd7c74

Refactor communication between Pipeline Components (deepset-ai#1321)
1 parent 3e6def7 commit 9dd7c74


51 files changed: +8445 −8319 lines

README.md

+1 −1
@@ -512,7 +512,7 @@ A minimal Open-Domain QA Pipeline:
 p = Pipeline()
 p.add_node(component=retriever, name="ESRetriever1", inputs=["Query"])
 p.add_node(component=reader, name="QAReader", inputs=["ESRetriever1"])
-res = p.run(query="What did Einstein work on?", top_k_retriever=1)
+res = p.run(query="What did Einstein work on?", params={"retriever": {"top_k": 1}})
 
 ```
 You can **draw the DAG** to inspect better what you are building:
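
For reference, a hedged sketch of the new calling convention (not part of this diff): the keys of `params` are node names as registered with `add_node`, so targeting the retriever defined above would look like this.

```python
# Illustrative only: per-node parameters are keyed by the registered node name.
res = p.run(
    query="What did Einstein work on?",
    params={"ESRetriever1": {"top_k": 1}},
)
```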

docs/_src/usage/usage/pipelines.md

+27 −26
@@ -21,7 +21,7 @@ from haystack import Pipeline
 p = Pipeline()
 p.add_node(component=retriever, name="ESRetriever1", inputs=["Query"])
 p.add_node(component=reader, name="QAReader", inputs=["ESRetriever1"])
-res = p.run(query="What did Einstein work on?", top_k_retriever=1)
+res = p.run(query="What did Einstein work on?")
 ```
 
 You can **draw the DAG** to better inspect what you are building:
@@ -32,16 +32,16 @@ p.draw(path="custom_pipe.png")
 
 ### Arguments
 
-Whatever keyword arguments are passed into the `Pipeline.run()` method will be passed on to each node in the pipeline.
-For example, in the code snippet below, all nodes will receive `query`, `top_k_retriever` and `top_k_reader` as argument,
-even if they don't use those arguments. It is therefore very important when defining custom nodes that their
-keyword argument names do not clash with the other nodes in your pipeline.
+Each node in a Pipeline defines the arguments its `run()` method accepts. The Pipeline class takes care of passing the
+relevant arguments to each node. In addition to mandatory inputs like `query`, `run()` accepts optional node parameters
+like `top_k` via the `params` argument. For instance, `params={"top_k": 5}` sets the `top_k` of all nodes to 5. To
+target params at a specific node, the node name can be explicitly specified, as in `params={"Retriever": {"top_k": 5}}`.
+
 
 ```python
 res = pipeline.run(
     query="What did Einstein work on?",
-    top_k_retriever=1,
-    top_k_reader=5
+    params={"Retriever": {"top_k": 5}, "Reader": {"top_k": 3}}
 )
 ```
 
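To make the new interface concrete, a short illustrative sketch (not taken from this commit; the node names `Retriever` and `Reader` are assumed to have been registered via `add_node`):

```python
# Global form: the parameter is applied to every node that declares it in run().
res = pipeline.run(query="What did Einstein work on?", params={"top_k": 5})

# Per-node form: keys are node names, values are that node's parameters.
res = pipeline.run(
    query="What did Einstein work on?",
    params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}},
)
```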
@@ -95,38 +95,39 @@ For another example YAML config, check out [this file](https://github.com/deepse
 ### Multiple retrievers
 You can now also use multiple Retrievers and join their results:
 ```python
-from haystack import Pipeline
-
 p = Pipeline()
 p.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
 p.add_node(component=dpr_retriever, name="DPRRetriever", inputs=["Query"])
 p.add_node(component=JoinDocuments(join_mode="concatenate"), name="JoinResults", inputs=["ESRetriever", "DPRRetriever"])
 p.add_node(component=reader, name="QAReader", inputs=["JoinResults"])
-res = p.run(query="What did Einstein work on?", top_k_retriever=1)
+res = p.run(query="What did Einstein work on?", params={"ESRetriever": {"top_k": 1}, "DPRRetriever": {"top_k": 3}})
 ```
 ![image](https://user-images.githubusercontent.com/1563902/102451782-7bd80400-4039-11eb-9046-01b002a783f8.png)
 
 ### Custom nodes
-You can easily build your own custom nodes. Just respect the following requirements:
+It is easy to build custom nodes. Just respect the following requirements:
 
-1. Add a method `run(self, **kwargs)` to your class. `**kwargs` will contain the output from the previous node in your graph.
-2. Do whatever you want within `run()` (e.g. reformatting the query)
-3. Return a tuple that contains your output data (for the next node) and the name of the outgoing edge `output_dict, "output_1`
-4. Add a class attribute `outgoing_edges = 1` that defines the number of output options from your node. You only need a higher number here if you have a decision node (see below).
+1. Create a class that inherits from `BaseComponent`.
+2. Add a `run()` method to your class with any parameters it needs to process the input. Ensure that the parameters are either passed with `params` to the pipeline or are returned by the preceding nodes.
+3. Do whatever you want within `run()` (e.g., reformatting the query).
+4. Return a tuple that contains your output data (for the next node) and the name of the outgoing edge, e.g., `output_dict, "output_1"`.
+5. Add a class attribute `outgoing_edges = 1` that defines your node's number of output options. You only need a higher number here if you have a decision node (see below).
 
 ### Decision nodes
 Or you can add decision nodes where only one "branch" is executed afterwards. This allows, for example, to classify an incoming query and depending on the result routing it to different modules:
 ![image](https://user-images.githubusercontent.com/1563902/102452199-41229b80-403a-11eb-9365-7038697e7c3e.png)
-```python
-class QueryClassifier():
+```python
+from haystack import BaseComponent, Pipeline
+
+class QueryClassifier(BaseComponent):
     outgoing_edges = 2
 
-    def run(self, **kwargs):
-        if "?" in kwargs["query"]:
-            return (kwargs, "output_1")
+    def run(self, query):
+        if "?" in query:
+            return {}, "output_1"
 
         else:
-            return (kwargs, "output_2")
+            return {}, "output_2"
 
 pipe = Pipeline()
 pipe.add_node(component=QueryClassifier(), name="QueryClassifier", inputs=["Query"])
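
As an aside, a minimal single-output node satisfying the five requirements above might look like the following sketch (illustrative only; the class, parameter, and node names are invented and not part of this commit):

```python
from haystack import BaseComponent

class QueryCleaner(BaseComponent):                 # 1. inherit from BaseComponent
    outgoing_edges = 1                             # 5. a single outgoing edge (not a decision node)

    def run(self, query: str, prefix: str = ""):   # 2. declare the inputs/params the node needs
        cleaned = (prefix + query).strip()         # 3. arbitrary processing
        return {"query": cleaned}, "output_1"      # 4. (output dict, outgoing edge name)
```

Whatever the node returns (here the rewritten `query`) is what the following nodes can consume; `prefix` could be supplied via `params={"QueryCleaner": {"prefix": "..."}}`, assuming the node was registered under that name.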
@@ -135,7 +136,7 @@ Or you can add decision nodes where only one "branch" is executed afterwards. Th
 pipe.add_node(component=JoinDocuments(join_mode="concatenate"), name="JoinResults",
               inputs=["ESRetriever", "DPRRetriever"])
 pipe.add_node(component=reader, name="QAReader", inputs=["JoinResults"])
-res = p.run(query="What did Einstein work on?", top_k_retriever=1)
+res = p.run(query="What did Einstein work on?", params={"ESRetriever": {"top_k": 1}, "DPRRetriever": {"top_k": 3}})
 ```
 
 ### Evaluation nodes
@@ -152,19 +153,19 @@ from haystack.pipeline import DocumentSearchPipeline, ExtractiveQAPipeline, Pipe
 
 # Extractive QA
 qa_pipe = ExtractiveQAPipeline(reader=reader, retriever=retriever)
-res = qa_pipe.run(query="When was Kant born?", top_k_retriever=3, top_k_reader=5)
+res = qa_pipe.run(query="When was Kant born?", params={"retriever": {"top_k": 3}, "reader": {"top_k": 5}})
 
 # Document Search
 doc_pipe = DocumentSearchPipeline(retriever=retriever)
-res = doc_pipe.run(query="Physics Einstein", top_k_retriever=1)
+res = doc_pipe.run(query="Physics Einstein", params={"retriever": {"top_k": 3}})
 
 # Generative QA
 doc_pipe = GenerativeQAPipeline(generator=rag_generator, retriever=retriever)
-res = doc_pipe.run(query="Physics Einstein", top_k_retriever=1)
+res = doc_pipe.run(query="Physics Einstein", params={"retriever": {"top_k": 3}})
 
 # FAQ based QA
 doc_pipe = FAQPipeline(retriever=retriever)
-res = doc_pipe.run(query="How can I change my address?", top_k_retriever=3)
+res = doc_pipe.run(query="How can I change my address?", params={"retriever": {"top_k": 3}})
 
 ```
 So to migrate your QA system from the deprecated `Finder` to `ExtractiveQAPipeline` you'd need to:

haystack/classifier/base.py

+2 −6
@@ -26,7 +26,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         pass
 
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, **kwargs):  # type: ignore
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):  # type: ignore
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
@@ -36,11 +36,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
 
         document_ids = [doc.id for doc in results]
         logger.debug(f"Retrieved documents with IDs: {document_ids}")
-        output = {
-            "query": query,
-            "documents": results,
-            **kwargs
-        }
+        output = {"documents": results}
 
         return output, "output_1"

haystack/classifier/farm.py

+4 −5
@@ -31,13 +31,12 @@ class FARMClassifier(BaseClassifier):
 retriever = ElasticsearchRetriever(document_store=document_store)
 classifier = FARMClassifier(model_name_or_path="deepset/bert-base-german-cased-sentiment-Germeval17")
 p = Pipeline()
-p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-p.add_node(component=classifier, name="Classifier", inputs=["ESRetriever"])
+p.add_node(component=retriever, name="Retriever", inputs=["Query"])
+p.add_node(component=classifier, name="Classifier", inputs=["Retriever"])
 
-res = p_extractive.run(
+res = p.run(
     query="Who is the father of Arya Stark?",
-    top_k_retriever=10,
-    top_k_reader=5
+    params={"Retriever": {"top_k": 10}, "Classifier": {"top_k": 5}}
 )
 
 print(res["documents"][0].to_dict()["meta"]["classification"]["label"])

haystack/connector/crawler.py

+10 −5
@@ -145,10 +145,15 @@ def _write_to_files(self, urls: List[str], output_dir: Path, base_url: str = Non
 
         return paths
 
-    def run(self, output_dir: Union[str, Path, None] = None, urls: Optional[List[str]] = None,  # type: ignore
-            crawler_depth: Optional[int] = None, filter_urls: Optional[List] = None,  # type: ignore
-            overwrite_existing_files: Optional[bool] = None, return_documents: Optional[bool] = False,  # type: ignore
-            **kwargs) -> Tuple[Dict, str]:  # type: ignore
+    def run(  # type: ignore
+        self,
+        output_dir: Union[str, Path, None] = None,
+        urls: Optional[List[str]] = None,
+        crawler_depth: Optional[int] = None,
+        filter_urls: Optional[List] = None,
+        overwrite_existing_files: Optional[bool] = None,
+        return_documents: Optional[bool] = False,
+    ) -> Tuple[Dict, str]:
         """
         Method to be executed when the Crawler is used as a Node within a Haystack pipeline.
 
@@ -172,7 +177,7 @@ def run(self, output_dir: Union[str, Path, None] = None, urls: Optional[List[str
             results = {"documents": crawled_data}
         else:
             results = {"paths": file_paths}
-        results.update(**kwargs)
+
         return results, "output_1"
 
     @staticmethod
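
For orientation, a hedged usage sketch of the refactored signature (placeholder values; not part of the commit): crawler options are now explicit parameters instead of free-form `**kwargs`.

```python
# Illustrative only: calling the node's run() directly with its declared parameters.
results, edge = crawler.run(
    urls=["https://haystack.deepset.ai"],  # placeholder URL
    crawler_depth=0,
    return_documents=True,
)
# Per the code above, `results` is {"documents": [...]} when return_documents is True,
# {"paths": [...]} otherwise, and `edge` is "output_1".
```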

haystack/document_store/base.py

+2 −2
@@ -286,9 +286,9 @@ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Di
     def delete_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         pass
 
-    def run(self, documents: List[dict], index: Optional[str] = None, **kwargs):  # type: ignore
+    def run(self, documents: List[dict], index: Optional[str] = None):  # type: ignore
         self.write_documents(documents=documents, index=index)
-        return kwargs, "output_1"
+        return {}, "output_1"
 
     @abstractmethod
     def get_documents_by_id(self, ids: List[str], index: Optional[str] = None,
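
A brief, hypothetical sketch of the consequence for indexing flows (not part of this commit): the DocumentStore node writes whatever `documents` reach it and now returns an empty dict instead of echoing `**kwargs`.

```python
# Illustrative only: `docs` stands for a List[dict] produced by preceding nodes or
# passed in directly; "document" is just the usual default index name, used as a placeholder.
output, edge = document_store.run(documents=docs, index="document")
assert output == {} and edge == "output_1"
```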

haystack/eval.py

+28 −26
@@ -5,15 +5,15 @@
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
 
-from haystack import MultiLabel, Label
+from haystack import MultiLabel, Label, BaseComponent, Document
 
 from farm.evaluation.squad_evaluation import compute_f1 as calculate_f1_str
 from farm.evaluation.squad_evaluation import compute_exact as calculate_em_str
 
 logger = logging.getLogger(__name__)
 
 
-class EvalDocuments:
+class EvalDocuments(BaseComponent):
     """
     This is a pipeline node that should be placed after a node that returns a List of Document, e.g., Retriever or
     Ranker, in order to assess its performance. Performance metrics are stored in this class and updated as each
@@ -22,21 +22,22 @@ class EvalDocuments:
     a look at our evaluation tutorial for more info about open vs closed domain eval (
     https://haystack.deepset.ai/tutorials/evaluation).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
+
+    outgoing_edges = 1
+
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                             When False, correct retrieval is evaluated based on document_id.
         :param debug: When True, a record of each sample and its evaluation will be stored in EvalDocuments.log
         :param top_k: calculate eval metrics for top k results, e.g., recall@k
         """
-        self.outgoing_edges = 1
         self.init_counts()
         self.no_answer_warning = False
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
-        self.top_k_eval_documents = top_k_eval_documents
-        self.name = name
+        self.top_k = top_k
         self.too_few_docs_warning = False
         self.top_k_used = 0
 
@@ -53,25 +54,25 @@ def init_counts(self):
         self.reciprocal_rank_sum = 0.0
         self.has_answer_reciprocal_rank_sum = 0.0
 
-    def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
+    def run(self, documents: List[Document], labels: List[Label], top_k: Optional[int] = None):  # type: ignore
         """Run this node on one sample and its labels"""
         self.query_count += 1
-        retriever_labels = get_label(labels, kwargs["node_id"])
-        if not top_k_eval_documents:
-            top_k_eval_documents = self.top_k_eval_documents
+        retriever_labels = get_label(labels, self.name)
+        if not top_k:
+            top_k = self.top_k
 
         if not self.top_k_used:
-            self.top_k_used = top_k_eval_documents
-        elif self.top_k_used != top_k_eval_documents:
+            self.top_k_used = top_k
+        elif self.top_k_used != top_k:
             logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
-                           f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                           f"being run again with top_k={self.top_k}. "
                            f"The evaluation counter is being reset from this point so that the evaluation "
                            f"metrics are interpretable.")
             self.init_counts()
 
-        if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
-            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
-                           f"(currently set to {top_k_eval_documents}).")
+        if len(documents) < top_k and not self.too_few_docs_warning:
+            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k "
+                           f"(currently set to {top_k}).")
             self.too_few_docs_warning = True
 
         # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
@@ -89,7 +90,7 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k)
             self.reciprocal_rank_sum += retrieved_reciprocal_rank
             correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
@@ -101,11 +102,11 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
         self.recall = self.correct_retrieval_count / self.query_count
         self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
 
-        self.top_k_used = top_k_eval_documents
+        self.top_k_used = top_k
 
         if self.debug:
-            self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
-        return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1"
+            self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank})
+        return {"correct_retrieval": correct_retrieval}, "output_1"
 
     def is_correctly_retrieved(self, retriever_labels, predictions):
         return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
@@ -142,7 +143,7 @@ def print(self):
         print(f"mean_reciprocal_rank@{self.top_k_used}: {self.mean_reciprocal_rank:.4f}")
 
 
-class EvalAnswers:
+class EvalAnswers(BaseComponent):
     """
     This is a pipeline node that should be placed after a Reader in order to assess the performance of the Reader
     individually or to assess the extractive QA performance of the whole pipeline. Performance metrics are stored in
@@ -152,6 +153,8 @@ class EvalAnswers:
     open vs closed domain eval (https://haystack.deepset.ai/tutorials/evaluation).
     """
 
+    outgoing_edges = 1
+
     def __init__(self,
                  skip_incorrect_retrieval: bool = True,
                  open_domain: bool = True,
@@ -174,7 +177,6 @@ def __init__(self,
                       - Large model for German only: "deepset/gbert-large-sts"
         :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log
         """
-        self.outgoing_edges = 1
         self.log: List = []
         self.debug = debug
         self.skip_incorrect_retrieval = skip_incorrect_retrieval
@@ -203,14 +205,14 @@ def init_counts(self):
         self.top_1_sas = 0.0
         self.top_k_sas = 0.0
 
-    def run(self, labels, answers, **kwargs):
+    def run(self, labels: List[Label], answers: List[dict], correct_retrieval: bool):  # type: ignore
         """Run this node on one sample and its labels"""
        self.query_count += 1
         predictions = answers
-        skip = self.skip_incorrect_retrieval and not kwargs.get("correct_retrieval")
+        skip = self.skip_incorrect_retrieval and not correct_retrieval
         if predictions and not skip:
             self.correct_retrieval_count += 1
-            multi_labels = get_label(labels, kwargs["node_id"])
+            multi_labels = get_label(labels, self.name)
             # If this sample is impossible to answer and expects a no_answer response
             if multi_labels.no_answer:
                 self.no_answer_count += 1
@@ -254,7 +256,7 @@ def run(self, labels, answers, **kwargs):
                 self.top_k_em_count += top_k_em
                 self.top_k_f1_sum += top_k_f1
             self.update_has_answer_metrics()
-        return {**kwargs}, "output_1"
+        return {}, "output_1"
 
     def evaluate_extraction(self, gold_labels, predictions):
         if self.open_domain:
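
Taken together, the eval changes show the new communication style: `EvalDocuments` now emits only `{"correct_retrieval": ...}`, and `EvalAnswers` declares `correct_retrieval` as a `run()` parameter instead of reading it from `**kwargs`. A hedged wiring sketch (node names and the `labels`/component variables are illustrative, not from this commit):

```python
from haystack import Pipeline
from haystack.eval import EvalDocuments, EvalAnswers

p = Pipeline()
p.add_node(component=retriever, name="Retriever", inputs=["Query"])
p.add_node(component=EvalDocuments(top_k=10), name="EvalDocuments", inputs=["Retriever"])
p.add_node(component=reader, name="Reader", inputs=["EvalDocuments"])
p.add_node(component=EvalAnswers(), name="EvalAnswers", inputs=["Reader"])

# `labels` are the gold annotations for the query; EvalAnswers gets `answers`
# from the Reader and `correct_retrieval` from EvalDocuments' output dict.
res = p.run(query="Who is the father of Arya Stark?", labels=labels,
            params={"Retriever": {"top_k": 10}})
```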
