Skip to content

Commit

Permalink
Update recognizer_registry.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rgupta2508 authored Feb 5, 2025
1 parent a4109d3 commit 38a3d65
Showing 1 changed file with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ def get_recognizers(
all_possible_recognizers = copy.copy(self.recognizers)
if ad_hoc_recognizers:
all_possible_recognizers.extend(ad_hoc_recognizers)

# filter out unwanted recognizers
to_return = set()
if all_fields:
to_return = [
Expand All @@ -170,23 +168,26 @@ def get_recognizers(
if language == rec.supported_language
]
else:
entities = set(entities)
subset = [
rec
for rec in all_possible_recognizers
if bool(set(rec.supported_entities).intersection(entities))
and language == rec.supported_language
]
if not subset:
logger.warning(
"Entity %s doesn't have the corresponding"
" recognizer in language : %s",
entity,
language,
)
else:
to_return.update(set(subset))

# filter out unwanted recognizers
all_entity_recognizers = dict()
for rec in all_possible_recognizers:
if type(rec.supported_entities) == list and len(rec.supported_entities) > 0:
for supported_entity in rec.supported_entities:
all_entity_recognizers[supported_entity] = all_entity_recognizers[supported_entity].add(
rec) if supported_entity in all_entity_recognizers else {rec}
elif type(rec.supported_entities) == str:
all_entity_recognizers[rec.supported_entities] = all_entity_recognizers[rec.supported_entities].add(
rec) if rec.supported_entities in all_entity_recognizers else {rec}
for entity in entities:
if entity in all_entity_recognizers:
to_return.update(all_entity_recognizers[entity])
else:
logger.warning(
"Entity %s doesn't have the corresponding"
" recognizer in language : %s",
entity,
language,
)
logger.debug(
"Returning a total of %s recognizers",
str(len(to_return)),
Expand Down

0 comments on commit 38a3d65

Please sign in to comment.