Skip to content

Commit bc71da7

Browse files
Support for hipifying CuPy and add extra_extensions argument in hipify (ROCm#52)
* added optional extensions for hipification * add cupy special map with cupy related replacements * update cupy mappings * add support for custom mappings * update based on review comments and update readme * Update README.md --------- Co-authored-by: Jithun Nair <[email protected]>
1 parent 43e6fb5 commit bc71da7

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

README.md

+18
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ It can also "hipify" the header include statements in your source code to ensure
1111
- [Through python](#through-python)
1212
- [Utilities](#utilities)
1313
- [CMake utility function](#cmake-utility-function)
14+
- [Custom hipify mapping](#custom-hipify-mapping)
1415
- [Intended users](#intended-users)
1516

1617
<!-- tocstop -->
@@ -100,6 +101,23 @@ get_hipified_list("${TP_CUDA_SRCS}" TP_CUDA_SRCS)
100101
Here the `TP_CUDA_SRCS` in the input list containing cuda source files and doing a inplace update with output list `TP_CUDA_SRCS`
101102
For the file suffix unique string, list variable name itself is passed as a string.
102103

104+
# Custom hipify mapping
105+
106+
Users can define their own custom mapping by adding a custom_hipify_mapping.json from project_directory from where the hipify() function is being called.
107+
To use a JSON file from a different directory, users can pass in the JSON file path via ```custom_map``` argument in the hipify method.
108+
The custom hipify mappings will be applied *before* any other default hipify mappings.
109+
The below is the sample JSON file:
110+
111+
```
112+
{
113+
"custom_map" : {
114+
"src mapping 1" : "dst mapping 1",
115+
"src mapping 2" : "dst mapping 2",
116+
...
117+
}
118+
}
119+
```
120+
103121
# Intended users
104122

105123
This module can be used to

hipify_torch/hipify_python.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import shutil
3030
import sys
3131
import os
32+
import json
3233

3334
from . import constants
3435
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
@@ -46,6 +47,8 @@
4647
to their actual types."""
4748
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
4849

50+
# Custom mapping json file (default), if the file is not available hipify call doesn't process it.
51+
custom_mapping_file = "custom_hipify_mappings.json"
4952

5053
class InputError(Exception):
5154
# Exception raised for errors in the input.
@@ -682,7 +685,8 @@ def pattern(self):
682685
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
683686
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
684687
PYTORCH_SPECIAL_MAP = {}
685-
688+
CUSTOM_TRIE = Trie()
689+
CUSTOM_SPECIAL_MAP = {}
686690

687691
for mapping in CUDA_TO_HIP_MAPPINGS:
688692
assert isinstance(mapping, Mapping)
@@ -708,6 +712,17 @@ def pattern(self):
708712
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
709713
RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
710714

715+
# This function takes in the custom json file and looks for custom_mappings available.
716+
def update_custom_mappings():
717+
if os.path.exists(custom_mapping_file):
718+
with open(custom_mapping_file, 'r') as f:
719+
json_data = json.load(f)
720+
custom_map_dict = json_data['custom_map']
721+
for src in custom_map_dict.keys():
722+
CUSTOM_TRIE.add(src)
723+
dst = custom_map_dict[src]
724+
CUSTOM_SPECIAL_MAP[src] = dst
725+
711726
"""
712727
Returns a dict with the following keys:
713728
"hipified_path" : absolute path of hipified source file
@@ -754,6 +769,12 @@ def pt_special_repl(m):
754769
# checks SPARSE map first, and if a miss occurs, falls back to pytorch mappings
755770
return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))
756771

772+
# replace all custom mappings.
773+
if len(CUSTOM_SPECIAL_MAP) > 0:
774+
RE_CUSTOM_PREPROCESSOR = re.compile(CUSTOM_TRIE.pattern())
775+
def custom_repl(m):
776+
return CUSTOM_SPECIAL_MAP[m.group(0)]
777+
output_source = RE_CUSTOM_PREPROCESSOR.sub(custom_repl, output_source)
757778

758779
if is_pytorch_extension:
759780
output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
@@ -962,9 +983,11 @@ def hipify(
962983
show_detailed: bool = False,
963984
extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
964985
header_extensions: Iterable = (".cuh", ".h", ".hpp"),
986+
extra_extensions: Iterable = (),
965987
output_directory: str = "",
966988
header_include_dirs: Iterable = (),
967989
includes: Iterable = ('*',),
990+
custom_map_list: str = "",
968991
extra_files: Iterable = (),
969992
out_of_place_only: bool = False,
970993
ignores: Iterable = (),
@@ -982,6 +1005,12 @@ def hipify(
9821005
print("The project folder specified does not exist.")
9831006
sys.exit(1)
9841007

1008+
# custom mapping json file that is provided by user.
1009+
if custom_map_list:
1010+
global custom_mapping_file
1011+
custom_mapping_file = os.path.abspath(custom_map_list)
1012+
update_custom_mappings()
1013+
9851014
# If no output directory, provide a default one.
9861015
if not output_directory:
9871016
project_directory.rstrip("/")
@@ -991,6 +1020,9 @@ def hipify(
9911020
includes = [include.replace(project_directory, output_directory) for include in includes]
9921021
ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
9931022

1023+
# update extensions with optional extensions.
1024+
extensions = extensions + extra_extensions
1025+
9941026
# Copy from project directory to output directory if not done already.
9951027
if not os.path.exists(output_directory):
9961028
shutil.copytree(project_directory, output_directory)

0 commit comments

Comments
 (0)