Skip to content

Commit 6e0ca4b

Browse files
Palamida scan fix (#68)
* Changed copyright statements to meet AMD requirements. * Re-engineering Hipify Trie: (1) Re-engineering Trie. (2) More documentation or comments for easier understanding Ported from: pytorch/pytorch#118433
1 parent 479b234 commit 6e0ca4b

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

Diff for: LICENSE.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2017 AMD Compute Libraries
3+
Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved.
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

Diff for: hipify_torch/hipify_python.py

+52-22
Original file line numberDiff line numberDiff line change
@@ -628,40 +628,66 @@ def is_caffe2_gpu_file(rel_filepath):
628628
_, ext = os.path.splitext(filename)
629629
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
630630

631+
class TrieNode:
632+
"""A Trie node whose children are represented as a directory of char: TrieNode.
633+
A special char '' represents end of word
634+
"""
635+
636+
def __init__(self):
637+
self.children = {}
631638

632-
# Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
633-
class Trie():
634-
"""Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
639+
class Trie:
640+
"""Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
635641
The corresponding Regex should match much faster than a simple Regex union."""
636642

637643
def __init__(self):
638-
self.data = {}
644+
"""Initialize the trie with an empty root node."""
645+
self.root = TrieNode()
639646

640647
def add(self, word):
641-
ref = self.data
648+
"""Add a word to the Trie. """
649+
node = self.root
650+
642651
for char in word:
643-
ref[char] = char in ref and ref[char] or {}
644-
ref = ref[char]
645-
ref[''] = 1
652+
node.children.setdefault(char, TrieNode())
653+
node = node.children[char]
654+
node.children[''] = True # Mark the end of the word
646655

647656
def dump(self):
648-
return self.data
657+
"""Return the root node of Trie. """
658+
return self.root
649659

650660
def quote(self, char):
661+
""" Escape a char for regex. """
651662
return re.escape(char)
652663

653-
def _pattern(self, pData):
654-
data = pData
655-
if "" in data and len(data.keys()) == 1:
664+
def search(self, word):
665+
"""Search whether word is present in the Trie.
666+
Returns True if yes, else return False"""
667+
node = self.root
668+
for char in word:
669+
if char in node.children:
670+
node = node.children[char]
671+
else:
672+
return False
673+
674+
# make sure to check the end-of-word marker present
675+
return '' in node.children
676+
677+
def _pattern(self, root):
678+
"""Convert a Trie into a regular expression pattern"""
679+
node = root
680+
681+
if "" in node.children and len(node.children.keys()) == 1:
656682
return None
657683

658-
alt = []
659-
cc = []
660-
q = 0
661-
for char in sorted(data.keys()):
662-
if isinstance(data[char], dict):
684+
alt = [] # store alternative patterns
685+
cc = [] # store char to char classes
686+
q = 0 # for node representing the end of word
687+
for char in sorted(node.children.keys()):
688+
if isinstance(node.children[char], TrieNode):
663689
try:
664-
recurse = self._pattern(data[char])
690+
recurse = self._pattern(node.children[char])
665691
alt.append(self.quote(char) + recurse)
666692
except Exception:
667693
cc.append(self.quote(char))
@@ -684,12 +710,16 @@ def _pattern(self, pData):
684710
if cconly:
685711
result += "?"
686712
else:
687-
result = "(?:%s)?" % result
713+
result = f"(?:{result})?"
688714
return result
689715

690716
def pattern(self):
691-
return self._pattern(self.dump())
717+
"""Export the Trie to a regex pattern."""
718+
return self._pattern(self.root)
692719

720+
def export_to_regex(self):
721+
"""Export the Trie to a regex pattern."""
722+
return self._pattern(self.root)
693723

694724
CAFFE2_TRIE = Trie()
695725
CAFFE2_MAP = {}
@@ -724,8 +754,8 @@ def pattern(self):
724754
if constants.API_PYTORCH not in meta_data:
725755
CAFFE2_TRIE.add(src)
726756
CAFFE2_MAP[src] = dst
727-
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
728-
RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern()))
757+
RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
758+
RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.export_to_regex()))
729759

730760
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
731761
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')

0 commit comments

Comments
 (0)