diff --git a/presidio_evaluator/models/base_model.py b/presidio_evaluator/models/base_model.py index ae99780..bd07658 100644 --- a/presidio_evaluator/models/base_model.py +++ b/presidio_evaluator/models/base_model.py @@ -33,7 +33,7 @@ def __init__( self.verbose = verbose @abstractmethod - def predict(self, sample: InputSample) -> List[str]: + def predict(self, sample: InputSample, **kwargs) -> List[str]: """ Abstract. Returns the predicted tokens/spans from the evaluated model :param sample: Sample to be evaluated diff --git a/presidio_evaluator/models/presidio_analyzer_wrapper.py b/presidio_evaluator/models/presidio_analyzer_wrapper.py index aa85abd..9598b2b 100644 --- a/presidio_evaluator/models/presidio_analyzer_wrapper.py +++ b/presidio_evaluator/models/presidio_analyzer_wrapper.py @@ -1,6 +1,6 @@ from typing import List, Optional, Dict -from presidio_analyzer import AnalyzerEngine +from presidio_analyzer import AnalyzerEngine, EntityRecognizer from presidio_evaluator import InputSample, span_to_tag from presidio_evaluator.models import BaseModel @@ -16,6 +16,9 @@ def __init__( score_threshold: float = 0.4, language: str = "en", entity_mapping: Optional[Dict[str, str]] = None, + ad_hoc_recognizers: Optional[List[EntityRecognizer]] = None, + context: Optional[List[str]] = None, + allow_list: Optional[List[str]] = None, ): """ Evaluation wrapper for the Presidio Analyzer @@ -29,25 +32,37 @@ def __init__( ) self.score_threshold = score_threshold self.language = language + self.ad_hoc_recognizers = ad_hoc_recognizers + self.context = context + self.allow_list = allow_list if not analyzer_engine: analyzer_engine = AnalyzerEngine() self._update_recognizers_based_on_entities_to_keep(analyzer_engine) self.analyzer_engine = analyzer_engine - def predict(self, sample: InputSample) -> List[str]: + def predict(self, sample: InputSample, **kwargs) -> List[str]: + language = kwargs.get("language", self.language) + score_threshold = kwargs.get("score_threshold", self.score_threshold) + ad_hoc_recognizers = kwargs.get("ad_hoc_recognizers", self.ad_hoc_recognizers) + context = kwargs.get("context", self.context) + allow_list = kwargs.get("allow_list", self.allow_list) results = self.analyzer_engine.analyze( text=sample.full_text, entities=self.entities, - language=self.language, - score_threshold=self.score_threshold, + language=language, + score_threshold=score_threshold, + ad_hoc_recognizers=ad_hoc_recognizers, + context=context, + allow_list=allow_list, + **kwargs, ) starts = [] ends = [] scores = [] tags = [] - # + for res in results: starts.append(res.start) ends.append(res.end)