@@ -628,40 +628,66 @@ def is_caffe2_gpu_file(rel_filepath):
628
628
_ , ext = os .path .splitext (filename )
629
629
return ('gpu' in filename or ext in ['.cu' , '.cuh' ]) and ('cudnn' not in filename )
630
630
631
+ class TrieNode :
632
+ """A Trie node whose children are represented as a directory of char: TrieNode.
633
+ A special char '' represents end of word
634
+ """
635
+
636
+ def __init__ (self ):
637
+ self .children = {}
631
638
632
- # Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
633
- class Trie ():
634
- """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
639
+ class Trie :
640
+ """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
635
641
The corresponding Regex should match much faster than a simple Regex union."""
636
642
637
643
def __init__ (self ):
638
- self .data = {}
644
+ """Initialize the trie with an empty root node."""
645
+ self .root = TrieNode ()
639
646
640
647
def add (self , word ):
641
- ref = self .data
648
+ """Add a word to the Trie. """
649
+ node = self .root
650
+
642
651
for char in word :
643
- ref [ char ] = char in ref and ref [ char ] or {}
644
- ref = ref [char ]
645
- ref ['' ] = 1
652
+ node . children . setdefault ( char , TrieNode ())
653
+ node = node . children [char ]
654
+ node . children ['' ] = True # Mark the end of the word
646
655
647
656
def dump (self ):
648
- return self .data
657
+ """Return the root node of Trie. """
658
+ return self .root
649
659
650
660
def quote (self , char ):
661
+ """ Escape a char for regex. """
651
662
return re .escape (char )
652
663
653
- def _pattern (self , pData ):
654
- data = pData
655
- if "" in data and len (data .keys ()) == 1 :
664
+ def search (self , word ):
665
+ """Search whether word is present in the Trie.
666
+ Returns True if yes, else return False"""
667
+ node = self .root
668
+ for char in word :
669
+ if char in node .children :
670
+ node = node .children [char ]
671
+ else :
672
+ return False
673
+
674
+ # make sure to check the end-of-word marker present
675
+ return '' in node .children
676
+
677
+ def _pattern (self , root ):
678
+ """Convert a Trie into a regular expression pattern"""
679
+ node = root
680
+
681
+ if "" in node .children and len (node .children .keys ()) == 1 :
656
682
return None
657
683
658
- alt = []
659
- cc = []
660
- q = 0
661
- for char in sorted (data .keys ()):
662
- if isinstance (data [char ], dict ):
684
+ alt = [] # store alternative patterns
685
+ cc = [] # store char to char classes
686
+ q = 0 # for node representing the end of word
687
+ for char in sorted (node . children .keys ()):
688
+ if isinstance (node . children [char ], TrieNode ):
663
689
try :
664
- recurse = self ._pattern (data [char ])
690
+ recurse = self ._pattern (node . children [char ])
665
691
alt .append (self .quote (char ) + recurse )
666
692
except Exception :
667
693
cc .append (self .quote (char ))
@@ -684,12 +710,16 @@ def _pattern(self, pData):
684
710
if cconly :
685
711
result += "?"
686
712
else :
687
- result = "(?:%s )?" % result
713
+ result = f "(?:{ result } )?"
688
714
return result
689
715
690
716
def pattern (self ):
691
- return self ._pattern (self .dump ())
717
+ """Export the Trie to a regex pattern."""
718
+ return self ._pattern (self .root )
692
719
720
+ def export_to_regex (self ):
721
+ """Export the Trie to a regex pattern."""
722
+ return self ._pattern (self .root )
693
723
694
724
CAFFE2_TRIE = Trie ()
695
725
CAFFE2_MAP = {}
@@ -724,8 +754,8 @@ def pattern(self):
724
754
if constants .API_PYTORCH not in meta_data :
725
755
CAFFE2_TRIE .add (src )
726
756
CAFFE2_MAP [src ] = dst
727
- RE_CAFFE2_PREPROCESSOR = re .compile (CAFFE2_TRIE .pattern ())
728
- RE_PYTORCH_PREPROCESSOR = re .compile (r'(?<=\W)({0})(?=\W)' .format (PYTORCH_TRIE .pattern ()))
757
+ RE_CAFFE2_PREPROCESSOR = re .compile (CAFFE2_TRIE .export_to_regex ())
758
+ RE_PYTORCH_PREPROCESSOR = re .compile (r'(?<=\W)({0})(?=\W)' .format (PYTORCH_TRIE .export_to_regex ()))
729
759
730
760
RE_QUOTE_HEADER = re .compile (r'#include "([^"]+)"' )
731
761
RE_ANGLE_HEADER = re .compile (r'#include <([^>]+)>' )
0 commit comments