diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index d49abdaa..1138d7d8 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit d49abdaa8711cc3f690f8ffe00f7393b2708a28f +Subproject commit 1138d7d8fbc126b635871b3a0283a062ae69f8c3 diff --git a/aiter/__init__.py b/aiter/__init__.py index 03ded4ee..b63c14dc 100644 --- a/aiter/__init__.py +++ b/aiter/__init__.py @@ -4,14 +4,40 @@ import torch import os import logging + + logger = logging.getLogger("aiter") + + +def getLogger(): + global logger + if not logger.handlers: + logger.setLevel(logging.DEBUG) + + console_handler = logging.StreamHandler() + if int(os.environ.get("AITER_LOG_MORE", 0)): + formatter = logging.Formatter( + fmt="[%(name)s %(levelname)s] %(asctime)s.%(msecs)03d - %(processName)s:%(process)d - %(pathname)s:%(lineno)d - %(funcName)s\n%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + else: + formatter = logging.Formatter( + fmt="[%(name)s] %(message)s", + ) + console_handler.setFormatter(formatter) + console_handler.setLevel(logging.INFO) + logger.addHandler(console_handler) + + return logger + + +logger = getLogger() + import importlib.util -if importlib.util.find_spec('aiter_') is not None: + +if importlib.util.find_spec("aiter_") is not None: from aiter_ import * -# if importlib.util.find_spec('hipbsolidxgemm_') is not None: -# from hipbsolidxgemm_ import * -# if importlib.util.find_spec('rocsolidxgemm_') is not None: -# from rocsolidxgemm_ import * +from .jit import core from .ops.norm import * from .ops.quant import * from .ops.gemm_op_a8w8 import * @@ -33,23 +59,3 @@ from .ops.mha import * from .ops.gradlib import * from . import mla - -def getLogger(): - global logger - if not logger.handlers: - logger.setLevel(logging.DEBUG) - - console_handler = logging.StreamHandler() - if int(os.environ.get('AITER_LOG_MORE', 0)): - formatter = logging.Formatter( - fmt="[%(name)s %(levelname)s] %(asctime)s.%(msecs)03d - %(process)d:%(processName)s - %(pathname)s:%(lineno)d - %(funcName)s\n%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - console_handler.setFormatter(formatter) - console_handler.setLevel(logging.INFO) - logger.addHandler(console_handler) - - return logger - - -logger = getLogger() diff --git a/aiter/configs/tuned_fmoe.csv b/aiter/configs/tuned_fmoe.csv new file mode 100644 index 00000000..ef184982 --- /dev/null +++ b/aiter/configs/tuned_fmoe.csv @@ -0,0 +1,18 @@ +token,model_dim,inter_dim,expert,topk,dtype,q_dtype,q_type,use_g1u1,us,tag,err +1,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,40.3231600000001,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128,0.0% +2,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,63.736449999999934,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3,0.1% +4,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,71.99169000000012,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3,0.0% +8,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,104.50463000000003,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2,0.1% +16,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,128.4107499999999,fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2,0.1% +32,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,130.71911999999998,fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2,0.1% +64,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,137.25394999999983,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3,0.1% 
+128,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,164.59966000000009,fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3,0.1% +256,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,237.0920099999999,fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2,0.1% +512,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,410.4526800000004,fmoe_stage1_bf16_pertokenFp8_g1u1_48x128,0.1% +1024,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,670.9813600000002,fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2,0.1% +1536,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,1002.8835099999991,fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2,0.1% +2048,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,1251.7141599999998,fmoe_stage1_bf16_pertokenFp8_g1u1_128x128,0.1% +3072,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,1780.3102,fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2,0.1% +4096,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,2329.7178200000008,fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2,0.1% +6144,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,3355.0757600000006,fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2,0.1% +8192,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1,4447.738730000001,fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2,0.1% diff --git a/aiter/configs/untuned_fmoe.csv b/aiter/configs/untuned_fmoe.csv new file mode 100644 index 00000000..f434e59a --- /dev/null +++ b/aiter/configs/untuned_fmoe.csv @@ -0,0 +1,18 @@ +token,model_dim,inter_dim,expert,topk,dtype,q_dtype,q_type,use_g1u1 +1,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +2,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +4,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +8,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +16,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +32,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +64,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +128,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +256,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +512,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +1024,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +1536,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +2048,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +3072,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +4096,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +6144,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 +8192,6144,4096,8,2,torch.bfloat16,torch.float8_e4m3fnuz,QuantType.per_Token,1 diff --git a/aiter/jit/core.py b/aiter/jit/core.py index ae0303d0..38ccbf82 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -18,8 +18,8 @@ from packaging.version import parse, Version PREBUILD_KERNELS = False -if os.path.exists(os.path.dirname(os.path.abspath(__file__))+"/aiter_.so"): - aiter_ = importlib.import_module(f'{__package__}.aiter_') +if os.path.exists(os.path.dirname(os.path.abspath(__file__)) + "/aiter_.so"): + aiter_ = importlib.import_module(f"{__package__}.aiter_") 
PREBUILD_KERNELS = True logger = logging.getLogger("aiter") @@ -37,6 +37,7 @@ package_path = find_aiter.origin package_path = os.path.dirname(package_path) import site + site_packages_dirs = site.getsitepackages() # develop mode if package_path not in site_packages_dirs: @@ -47,17 +48,16 @@ else: print("aiter is not installed.") -AITER_CSRC_DIR = f'{AITER_ROOT_DIR}/csrc' -AITER_GRADLIB_DIR = f'{AITER_ROOT_DIR}/gradlib' -os.environ["AITER_ASM_DIR"] = f'{AITER_ROOT_DIR}/hsa/' -CK_DIR = os.environ.get("CK_DIR", - f"{AITER_ROOT_DIR}/3rdparty/composable_kernel") +AITER_CSRC_DIR = f"{AITER_ROOT_DIR}/csrc" +AITER_GRADLIB_DIR = f"{AITER_ROOT_DIR}/gradlib" +os.environ["AITER_ASM_DIR"] = f"{AITER_ROOT_DIR}/hsa/" +CK_DIR = os.environ.get("CK_DIR", f"{AITER_ROOT_DIR}/3rdparty/composable_kernel") @functools.lru_cache(maxsize=None) def get_user_jit_dir(): - if 'JIT_WORKSPACE_DIR' in os.environ: - path = os.getenv('JIT_WORKSPACE_DIR') + if "JIT_WORKSPACE_DIR" in os.environ: + path = os.getenv("JIT_WORKSPACE_DIR") os.makedirs(path, exist_ok=True) return path else: @@ -69,20 +69,19 @@ def get_user_jit_dir(): return home_jit_dir -bd_dir = f'{get_user_jit_dir()}/build' +bd_dir = f"{get_user_jit_dir()}/build" # copy ck to build, thus hippify under bd_dir -if multiprocessing.current_process().name == 'MainProcess': - shutil.copytree(CK_DIR, f'{bd_dir}/ck', dirs_exist_ok=True) - if os.path.exists(f'{bd_dir}/ck/library'): - shutil.rmtree(f'{bd_dir}/ck/library') -CK_DIR = f'{bd_dir}/ck' +if multiprocessing.current_process().name == "MainProcess": + shutil.copytree(CK_DIR, f"{bd_dir}/ck", dirs_exist_ok=True) + if os.path.exists(f"{bd_dir}/ck/library"): + shutil.rmtree(f"{bd_dir}/ck/library") +CK_DIR = f"{bd_dir}/ck" def validate_and_update_archs(): archs = os.getenv("GPU_ARCHS", "native").split(";") # List of allowed architectures - allowed_archs = ["native", "gfx90a", - "gfx940", "gfx941", "gfx942", "gfx1100"] + allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx1100"] # Validate if each element in archs is in allowed_archs assert all( @@ -92,12 +91,14 @@ def validate_and_update_archs(): def check_and_set_ninja_worker(): - max_num_jobs_cores = int(max(1, os.cpu_count()*0.8)) - if int(os.environ.get("MAX_JOBS", '1')) < max_num_jobs_cores: + max_num_jobs_cores = int(max(1, os.cpu_count() * 0.8)) + if int(os.environ.get("MAX_JOBS", "1")) < max_num_jobs_cores: import psutil + # calculate the maximum allowed NUM_JOBS based on free memory - free_memory_gb = psutil.virtual_memory().available / \ - (1024 ** 3) # free memory in GB + free_memory_gb = psutil.virtual_memory().available / ( + 1024**3 + ) # free memory in GB # each JOB peak memory cost is ~8-9GB when threads = 4 max_num_jobs_memory = int(free_memory_gb / 9) @@ -112,53 +113,72 @@ def do_rename_and_mv(name, src, dst, ret): newName = name if name.endswith(".cpp") or name.endswith(".cu"): newName = name.replace(".cpp", ".cu") - ret.append(f'{dst}/{newName}') - shutil.copy(f'{src}/{name}', f'{dst}/{newName}') + ret.append(f"{dst}/{newName}") + shutil.copy(f"{src}/{name}", f"{dst}/{newName}") + ret = [] for el in els: if not os.path.exists(el): - logger.warning(f'---> {el} not exists!!!!!!') + logger.warning(f"---> {el} not exists!!!!!!") continue if os.path.isdir(el): for entry in os.listdir(el): - if os.path.isdir(f'{el}/{entry}'): + if os.path.isdir(f"{el}/{entry}"): if recurisve: - ret += rename_cpp_to_cu([f'{el}/{entry}'], - dst, recurisve) + ret += rename_cpp_to_cu([f"{el}/{entry}"], dst, recurisve) continue do_rename_and_mv(entry, el, dst, 
ret) else: - do_rename_and_mv(os.path.basename(el), - os.path.dirname(el), dst, ret) + do_rename_and_mv(os.path.basename(el), os.path.dirname(el), dst, ret) return ret def get_hip_version(): - return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) + return parse(torch.version.hip.split()[-1].rstrip("-").replace("-", "+")) -@functools.lru_cache(maxsize=1024) -def get_module(md_name): - numa_balance_set = os.popen( - "cat /proc/sys/kernel/numa_balancing").read().strip() +@functools.lru_cache() +def check_numa(): + numa_balance_set = os.popen("cat /proc/sys/kernel/numa_balancing").read().strip() if numa_balance_set == "1": - logger.warning("WARNING: NUMA balancing is enabled, which may cause errors. " - "It is recommended to disable NUMA balancing by running 'sudo sh -c echo 0 > /proc/sys/kernel/numa_balancing' " - "for more details: https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html#disable-numa-auto-balancing") - return importlib.import_module(f'{__package__}.{md_name}') + logger.warning( + "WARNING: NUMA balancing is enabled, which may cause errors. " + "It is recommended to disable NUMA balancing by running 'sudo sh -c echo 0 > /proc/sys/kernel/numa_balancing' " + "for more details: https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html#disable-numa-auto-balancing" + ) + +__mds = {} -def build_module(md_name, srcs, flags_extra_cc, flags_extra_hip, blob_gen_cmd, extra_include, extra_ldflags, verbose): + +@functools.lru_cache(maxsize=1024) +def get_module(md_name): + check_numa() + if md_name not in __mds: + __mds[md_name] = importlib.import_module(f"{__package__}.{md_name}") + return __mds[md_name] + + +def build_module( + md_name, + srcs, + flags_extra_cc, + flags_extra_hip, + blob_gen_cmd, + extra_include, + extra_ldflags, + verbose, +): startTS = time.perf_counter() try: - op_dir = f'{bd_dir}/{md_name}' - logger.info(f'start build [{md_name}] under {op_dir}') + op_dir = f"{bd_dir}/{md_name}" + logger.info(f"start build [{md_name}] under {op_dir}") - opbd_dir = f'{op_dir}/build' - src_dir = f'{op_dir}/build/srcs' + opbd_dir = f"{op_dir}/build" + src_dir = f"{op_dir}/build/srcs" os.makedirs(src_dir, exist_ok=True) - if os.path.exists(f'{get_user_jit_dir()}/{md_name}.so'): - os.remove(f'{get_user_jit_dir()}/{md_name}.so') + if os.path.exists(f"{get_user_jit_dir()}/{md_name}.so"): + os.remove(f"{get_user_jit_dir()}/{md_name}.so") sources = rename_cpp_to_cu(srcs, src_dir) @@ -170,8 +190,8 @@ def build_module(md_name, srcs, flags_extra_cc, flags_extra_hip, blob_gen_cmd, e "-D__HIP_PLATFORM_AMD__=1", "-U__HIP_NO_HALF_CONVERSIONS__", "-U__HIP_NO_HALF_OPERATORS__", - - "-mllvm", "--amdgpu-kernarg-preload-count=16", + "-mllvm", + "--amdgpu-kernarg-preload-count=16", # "-v", "--save-temps", "-Wno-unused-result", "-Wno-switch-bool", @@ -183,14 +203,18 @@ def build_module(md_name, srcs, flags_extra_cc, flags_extra_hip, blob_gen_cmd, e # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214 hip_version = get_hip_version() - if hip_version > Version('5.7.23302'): + if hip_version > Version("5.7.23302"): flags_hip += ["-fno-offload-uniform-block"] - if hip_version > Version('6.1.40090'): + if hip_version > Version("6.1.40090"): flags_hip += ["-mllvm", "-enable-post-misched=0"] - if hip_version > Version('6.2.41132'): - flags_hip += ["-mllvm", "-amdgpu-early-inline-all=true", - "-mllvm", "-amdgpu-function-calls=false"] - if hip_version > Version('6.2.41133'): + if hip_version 
> Version("6.2.41132"): + flags_hip += [ + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + ] + if hip_version > Version("6.2.41133"): flags_hip += ["-mllvm", "-amdgpu-coerce-illegal-types=1"] flags_cc += flags_extra_cc @@ -203,19 +227,19 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): if blob_gen_cmd: blob_dir = f"{op_dir}/blob" os.makedirs(blob_dir, exist_ok=True) - baton = FileBaton(os.path.join(blob_dir, 'lock')) + baton = FileBaton(os.path.join(blob_dir, "lock")) if baton.try_acquire(): try: if AITER_LOG_MORE: logger.info( - f'exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}') - os.system(f'{PY} {blob_gen_cmd.format(blob_dir)}') + f"exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}" + ) + os.system(f"{PY} {blob_gen_cmd.format(blob_dir)}") finally: baton.release() else: baton.wait() - sources += rename_cpp_to_cu([blob_dir], - src_dir, recurisve=True) + sources += rename_cpp_to_cu([blob_dir], src_dir, recurisve=True) return sources if isinstance(blob_gen_cmd, list): @@ -224,10 +248,9 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): else: sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources) - bd_include_dir = f'{op_dir}/build/include' + bd_include_dir = f"{op_dir}/build/include" os.makedirs(bd_include_dir, exist_ok=True) - rename_cpp_to_cu([f"{AITER_CSRC_DIR}/include"] + extra_include, - bd_include_dir) + rename_cpp_to_cu([f"{AITER_CSRC_DIR}/include"] + extra_include, bd_include_dir) extra_include_paths = [ f"{CK_DIR}/include", f"{CK_DIR}/library/include", @@ -246,28 +269,31 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): with_cuda=True, is_python_module=True, ) - shutil.copy(f'{opbd_dir}/{md_name}.so', f'{get_user_jit_dir()}') + shutil.copy(f"{opbd_dir}/{md_name}.so", f"{get_user_jit_dir()}") except Exception as e: - logger.error('failed build jit [{}]\n-->[History]: {}'.format( - md_name, - '-->'.join(traceback.format_exception(*sys.exc_info())) - )) + logger.error( + "failed build jit [{}]\n-->[History]: {}".format( + md_name, "-->".join(traceback.format_exception(*sys.exc_info())) + ) + ) raise Exception(f"failed build jit [{md_name}]...") - logger.info( - f'finish build [{md_name}], cost {time.perf_counter()-startTS:.8f}s') + logger.info(f"finish build [{md_name}], cost {time.perf_counter()-startTS:.8f}s") + if md_name not in __mds: + __mds[md_name] = module return module def get_args_of_build(ops_name: str, exclue=[]): - d_opt_build_args = {"srcs": [], - "md_name": "", - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": None, - "extra_include": [], - "verbose": False, - "blob_gen_cmd": "" - } + d_opt_build_args = { + "srcs": [], + "md_name": "", + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": None, + "extra_include": [], + "verbose": False, + "blob_gen_cmd": "", + } def convert(d_ops: dict): # judge isASM @@ -286,16 +312,19 @@ def convert(d_ops: dict): else: pass return d_ops - with open(this_dir+"/optCompilerConfig.json", 'r') as file: + + with open(this_dir + "/optCompilerConfig.json", "r") as file: data = json.load(file) if isinstance(data, dict): # parse all ops if ops_name == "all": - d_all_ops = {"srcs": [], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_include": [], - "blob_gen_cmd": []} + d_all_ops = { + "srcs": [], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_include": [], + "blob_gen_cmd": [], + } # traverse opts for ops_name, d_ops in data.items(): # Cannot contain tune ops @@ -308,15 +337,13 @@ def convert(d_ops: dict): for k 
in d_all_ops.keys(): if isinstance(single_ops[k], list): d_all_ops[k] += single_ops[k] - elif isinstance(single_ops[k], str) and single_ops[k] != '': + elif isinstance(single_ops[k], str) and single_ops[k] != "": d_all_ops[k].append(single_ops[k]) - # print(d_all_ops) return d_all_ops # no find opt_name in json. elif data.get(ops_name) == None: - logger.warning( - "Not found this operator in 'optCompilerConfig.json'. ") + logger.warning("Not found this operator in 'optCompilerConfig.json'. ") return d_opt_build_args # parser single opt else: @@ -324,7 +351,8 @@ def convert(d_ops: dict): return convert(compile_ops_) else: logger.warning( - "ERROR: pls use dict_format to write 'optCompilerConfig.json'! ") + "ERROR: pls use dict_format to write 'optCompilerConfig.json'! " + ) def compile_ops(_md_name: str, fc_name: Optional[str] = None): @@ -340,14 +368,14 @@ def wrapper(*args, custom_build_args={}, **kwargs): if hasattr(aiter_, loadName): module = aiter_ if module is None: - module = get_module(custom_build_args.get('md_name', - md_name)) - except Exception as e: + md = custom_build_args.get("md_name", md_name) + module = get_module(md) + except ModuleNotFoundError as e: d_args = get_args_of_build(md_name) d_args.update(custom_build_args) # update module if we have coustom build - md_name = custom_build_args.get('md_name', md_name) + md_name = custom_build_args.get("md_name", md_name) srcs = d_args["srcs"] flags_extra_cc = d_args["flags_extra_cc"] @@ -356,14 +384,25 @@ def wrapper(*args, custom_build_args={}, **kwargs): extra_include = d_args["extra_include"] extra_ldflags = d_args["extra_ldflags"] verbose = d_args["verbose"] - module = build_module(md_name, srcs, flags_extra_cc, flags_extra_hip, - blob_gen_cmd, extra_include, extra_ldflags, verbose) + module = build_module( + md_name, + srcs, + flags_extra_cc, + flags_extra_hip, + blob_gen_cmd, + extra_include, + extra_ldflags, + verbose, + ) op = getattr(module, loadName) if AITER_LOG_MORE == 2: from ..test_common import log_args + log_args(func, *args, **kwargs) return op(*args, **kwargs) + return wrapper + return decorator diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 7247a41e..7c669254 100644 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -197,7 +197,8 @@ "f'{AITER_CSRC_DIR}/kernels/topk_softmax_kernels.cu'", "f'{AITER_CSRC_DIR}/kernels/topk_softmax_kernels_group.cu'", "f'{AITER_CSRC_DIR}/kernels/moe_align_block_size_kernels.cu'", - "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_fmoe.cpp'" + "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_fmoe.cpp'", + "f'{AITER_CSRC_DIR}/py_itfs_cu/asm_moe_2stage.cpp'" ], "flags_extra_cc": [], "flags_extra_hip": [], diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index adce5ea2..1ad1c2e5 100644 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -3,10 +3,23 @@ from torch import Tensor from typing import List, Optional -from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR, AITER_CORE_DIR +from ..jit.core import ( + compile_ops, + CK_DIR, + AITER_CSRC_DIR, + AITER_ROOT_DIR, + AITER_CORE_DIR, +) import torch.nn.functional as F +@compile_ops("module_moe_asm", "ActivationType") +def _ActivationType(dummy): ... + + +ActivationType = _ActivationType(0) + + @compile_ops("module_moe_asm") def topk_softmax( topk_weights: Tensor, @@ -47,14 +60,6 @@ def fmoe( ): ... -@compile_ops("module_moe_asm", fc_name='ActivationType') -class _ActivationType(): - ... 
- - -ActivationType = _ActivationType(0) - - @compile_ops("module_moe_asm") def fmoe_int8_g1u0( out: Tensor, @@ -149,6 +154,24 @@ def fmoe_fp8_blockscale_g1u1( ): ... +@compile_ops("module_moe_asm") +def moe_stage1_fp8_g1u1( + out: Tensor, + input: Tensor, + gate: Tensor, + down: Tensor, + sorted_token_ids: Tensor, + sorted_weight_buf: Tensor, + sorted_expert_ids: Tensor, + num_valid_ids: Tensor, + topk: int, + fc1_scale: Tensor, + fc2_scale: Optional[Tensor] = None, + input_scale: Optional[Tensor] = None, + fc2_smooth_scale: Optional[Tensor] = None, +): ... + + @compile_ops("module_moe") def ck_moe( hidden_states: Tensor, @@ -161,7 +184,7 @@ def ck_moe( fc1_smooth_scale: Optional[Tensor] = None, fc2_smooth_scale: Optional[Tensor] = None, block_m: Optional[int] = 32, - expert_mask: Optional[Tensor] = None + expert_mask: Optional[Tensor] = None, ): ... @@ -177,7 +200,7 @@ def ck_moe_stage1( topk: int, w1_scale: Optional[Tensor] = None, a1_scale: Optional[Tensor] = None, - block_m: Optional[int] = 32 + block_m: Optional[int] = 32, ): ... @@ -194,5 +217,5 @@ def ck_moe_stage2( topk: int, w2_scale: Optional[Tensor] = None, a2_scale: Optional[Tensor] = None, - block_m: Optional[int] = 32 + block_m: Optional[int] = 32, ): ... diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 63101d27..ba0f9bd6 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -4,12 +4,12 @@ import torch from torch import Tensor from typing import List, Optional +from enum import Enum from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR @compile_ops("module_smoothquant") -def smoothquant_fwd(input: Tensor, out: Tensor, - x_scale: Tensor, y_scale: Tensor): ... +def smoothquant_fwd(input: Tensor, out: Tensor, x_scale: Tensor, y_scale: Tensor): ... @compile_ops("module_smoothquant") @@ -27,18 +27,17 @@ def get_dtype_max(dtype): return dtypeMax -def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch.int8, dtypeMax=None): +def pertoken_quant( + x, x_scale=None, scale_dtype=torch.float, quant_dtype=torch.int8, dtypeMax=None +): + x = x.to(torch.float) if x_scale is None: hidden_states = x else: # smooth quant hidden_states = x * x_scale # [m, 1] - per_token_amax, _ = torch.max( - input=torch.abs(hidden_states), - dim=-1, - keepdim=True - ) + per_token_amax, _ = torch.max(input=torch.abs(hidden_states), dim=-1, keepdim=True) if not dtypeMax: dtypeMax = get_dtype_max(quant_dtype) @@ -48,7 +47,7 @@ def pertoken_quant(x, y_scale_dtype=torch.float, x_scale=None, quant_dtype=torch # quant hidden_states y = (hidden_states / per_token_scale).to(dtype=quant_dtype) - y_scale = per_token_scale.to(y_scale_dtype) + y_scale = per_token_scale.to(scale_dtype) return y, y_scale @@ -57,9 +56,18 @@ def per_tensor_quant(x, scale=None, scale_dtype=torch.float, quant_dtype=torch.i if scale is None: dtypeMax = get_dtype_max(quant_dtype) scale = torch.abs(x).max() / dtypeMax - y = x/scale + y = x / scale + + return y.to(quant_dtype), scale.view(1).to(scale_dtype) - return y.to(quant_dtype), scale.to(scale_dtype) + +def get_torch_quant(qType): + tmp = { + QuantType.No: lambda *a, **k: (a[0], None), + QuantType.per_Tensor: per_tensor_quant, + QuantType.per_Token: pertoken_quant, + } + return tmp.get(qType, NotImplementedError) def per_tensor_quant_fp8_hip(x, scale=None): @@ -72,16 +80,19 @@ def per_tensor_quant_fp8_hip(x, scale=None): return y, scale +@compile_ops("module_quant", "QuantType") +def _QuantType(dummy): ... 
+ + +QuantType = _QuantType(0) + + @compile_ops("module_quant") -def static_scaled_fp8_quant( - out: Tensor, input: Tensor, scale: Tensor -): ... +def static_scaled_fp8_quant(out: Tensor, input: Tensor, scale: Tensor): ... @compile_ops("module_quant") -def dynamic_scaled_fp8_quant( - out: Tensor, input: Tensor, scale: Tensor -):... +def dynamic_scaled_fp8_quant(out: Tensor, input: Tensor, scale: Tensor): ... @compile_ops("module_quant") diff --git a/aiter/ops/shuffle.py b/aiter/ops/shuffle.py index f075b659..55cb2a08 100644 --- a/aiter/ops/shuffle.py +++ b/aiter/ops/shuffle.py @@ -4,11 +4,11 @@ import torch -def shuffle_weight(x: torch.Tensor, layout=(16, 16)) -> torch.Tensor: +def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Tensor: # Hardcode BLOCK_K and BLOCK_N IN, IK = layout BK = IK*2 - K = 16//x.element_size() + K = 16//x.element_size() if not use_int4 else 32 BN = IN assert (x.shape[-2] % BN == 0), f'{x.shape[-2]} % {BN} == {x.shape[-2] % BN }' diff --git a/aiter/test_common.py b/aiter/test_common.py index d1df7af6..665ff2be 100644 --- a/aiter/test_common.py +++ b/aiter/test_common.py @@ -10,15 +10,28 @@ from aiter import logger -def perftest(num_iters=101, num_warmup=5, testGraph=False, num_rotate_args=3): +def perftest(num_iters=101, num_warmup=5, testGraph=False, num_rotate_args=0): def decorator(func): def wrapper(*args, **kwargs): run_iters(num_warmup, func, *args, **kwargs) - rotate_args = [(copy.deepcopy(args), - copy.deepcopy(kwargs)) - for _ in range(num_rotate_args)] - - if int(os.environ.get('AITER_LOG_MORE', 0)): + num = num_rotate_args + if num < 1: + current_device_index = torch.cuda.current_device() + inputSize = sum( + [el.nbytes for el in args if isinstance(el, torch.Tensor)] + ) + cache_size = ( + torch.cuda.get_device_properties(current_device_index).L2_cache_size + * 64 + * 128 + ) + num = (cache_size + inputSize - 1) // inputSize + num = min(num, num_iters) + rotate_args = [ + (copy.deepcopy(args), copy.deepcopy(kwargs)) for _ in range(num) + ] + + if int(os.environ.get("AITER_LOG_MORE", 0)): latencies = [] start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) @@ -29,32 +42,36 @@ def wrapper(*args, **kwargs): end_event.synchronize() latencies.append(start_event.elapsed_time(end_event)) avg = np.mean(latencies) * 1000 - logger.info(f'avg: {avg} us/iter from cuda.Event') + logger.info(f"avg: {avg} us/iter from cuda.Event") if testGraph: graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): data = run_iters_rotate(num_iters, func, rotate_args) - with tpf.profile(activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], - profile_memory=True, - with_stack=True, - with_modules=True, - ) as prof: + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: run_iters(1, graph.replay) avg = get_trace_perf(prof, num_iters) - logger.info(f'avg: {avg} us/iter with hipgraph') - with tpf.profile(activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], - profile_memory=True, - with_stack=True, - with_modules=True, - # record_shapes=True, - # on_trace_ready=tpf.tensorboard_trace_handler( - # './aiter_logs/'), - ) as prof: + logger.info(f"avg: {avg} us/iter with hipgraph") + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=True, + with_stack=True, + with_modules=True, + # record_shapes=True, + # 
on_trace_ready=tpf.tensorboard_trace_handler( + # './aiter_logs/'), + ) as prof: data = run_iters_rotate(num_iters, func, rotate_args) avg = get_trace_perf(prof, num_iters) return data, avg + return wrapper + return decorator @@ -63,7 +80,9 @@ def decorator(func): def wrapper(*args, **kwargs): log_args(func, *args, **kwargs) return func(*args, **kwargs) + return wrapper + return decorator @@ -83,119 +102,142 @@ def run_iters_rotate(num_iters, func, rotate_args): return data -def run_perftest(func, *args, num_iters=101, num_warmup=10, **kwargs): +def run_perftest( + func, *args, num_iters=101, num_warmup=10, num_rotate_args=0, **kwargs +): @perftest(num_iters=num_iters, num_warmup=num_warmup) - def worker(): + def worker(*args, **kwargs): return func(*args, **kwargs) - return worker() + + return worker(*args, **kwargs) def log_args(func, *args, **kwargs): import inspect + callargs = inspect.getcallargs(func, *args, **kwargs) prefix = f"calling {func.__name__}(" - blanks = ' '*len(prefix) + blanks = " " * len(prefix) def getTensorInfo(el): if isinstance(el, torch.Tensor): - return f'{el.shape} {el.dtype} {hex(el.data_ptr())}' + return f"{el.shape} {el.dtype} {hex(el.data_ptr())}" elif isinstance(el, tuple): viewNum = 5 if len(el) > viewNum: - el = list(el[:viewNum])+['...'] - return f'\n{" "*(len(prefix)+31)}'.join(['(']+[f" {getTensorInfo(e)}" for e in el]+[')']) + el = list(el[:viewNum]) + ["..."] + return f'\n{" "*(len(prefix)+31)}'.join( + ["("] + [f" {getTensorInfo(e)}" for e in el] + [")"] + ) return el + callargs = [f"{el:<28} = {getTensorInfo(callargs[el])}" for el in callargs] - callargs = f',\n{blanks}'.join(callargs) + callargs = f",\n{blanks}".join(callargs) logger.info(f"\n{prefix}{callargs})") def get_trace_perf(prof, num_iters): - assert (num_iters > 1) + assert num_iters > 1 num_iters -= 1 df = [] - cols = ['name', 'self_cpu_time_total', 'self_device_time_total', - 'device_type', 'device_index',] + cols = [ + "name", + "self_cpu_time_total", + "self_device_time_total", + "device_type", + "device_index", + ] for el in prof.events(): df.append([getattr(el, x, None) for x in cols]) df = pd.DataFrame(df, columns=cols) - df['cnt'] = 1 + df["cnt"] = 1 rets = [] - for name, d in df.groupby('name', sort=False): - r = d.iloc[1:][['cnt', - 'self_cpu_time_total', - 'self_device_time_total']].sum() + for name, d in df.groupby("name", sort=False): + r = d.iloc[1:][["cnt", "self_cpu_time_total", "self_device_time_total"]].sum() if not r.empty: - device_type = str(d['device_type'].iat[0]).split('.')[-1] - r['name'] = name - r['device_type'] = device_type - r['device_index'] = str(d['device_index'].iat[0]) - if device_type == 'CUDA': - r['device_time_total'] = r['self_device_time_total'] - r['host_time_total'] = 0 + device_type = str(d["device_type"].iat[0]).split(".")[-1] + r["name"] = name + r["device_type"] = device_type + r["device_index"] = str(d["device_index"].iat[0]) + if device_type == "CUDA": + r["device_time_total"] = r["self_device_time_total"] + r["host_time_total"] = 0 else: - r['host_time_total'] = r['self_device_time_total'] - r['device_time_total'] = 0 + r["host_time_total"] = r["self_device_time_total"] + r["device_time_total"] = 0 rets.append(r) df = pd.DataFrame(rets) - cols = ['name', 'cnt', 'host_time_total', 'device_time_total', - 'device_type', 'device_index',] + cols = [ + "name", + "cnt", + "host_time_total", + "device_time_total", + "device_type", + "device_index", + ] cols = [el for el in cols if el in df.columns] df = df[(df.host_time_total > 0) | 
(df.device_time_total > 0)] - timerList = ['host_time_total', 'device_time_total', ] + timerList = [ + "host_time_total", + "device_time_total", + ] df = df[cols].sort_values(timerList, ignore_index=True) - avg_name = '[avg us/iter]' + avg_name = "[avg us/iter]" for el in timerList: - df.at[avg_name, el] = df[el].sum()/num_iters - if int(os.environ.get('AITER_LOG_MORE', 0)): - pd.set_option('display.max_colwidth', 120) - logger.info(f'{df}') - return df.at[avg_name, 'device_time_total'] + df.at[avg_name, el] = df[el].sum() / num_iters + if int(os.environ.get("AITER_LOG_MORE", 0)): + pd.set_option("display.max_colwidth", 120) + logger.info(f"{df}") + return df.at[avg_name, "device_time_total"] -def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg='', printNum=8): +def checkAllclose(a, b, rtol=1e-2, atol=1e-2, msg="", printNum=8): isClose = torch.isclose(a, b, rtol=rtol, atol=atol) mask = ~isClose if isClose.all(): - logger.info(f'{msg}[checkAllclose {atol=} {rtol=} passed~]') - return True + logger.info(f"{msg}[checkAllclose {atol=} {rtol=} passed~]") + return 0 else: num = mask.sum() printNum = min(printNum, num) - percent = num/a.numel() - delta = (a-b)[mask] + percent = (num / a.numel()).item() + delta = (a - b)[mask] if percent > 0.01: - logger.info(f'''{msg}[checkAllclose {atol=} {rtol=} failed!] + logger.info( + f"""{msg}[checkAllclose {atol=} {rtol=} failed!] a : {a.shape} {a[mask][:printNum]} b : {b.shape} {b[mask][:printNum]} delta: - {delta[:printNum]}''') + {delta[:printNum]}""" + ) else: logger.info( - f'''{msg}[checkAllclose {atol=} {rtol=} waring!] a and b results are not all close''') + f"""{msg}[checkAllclose {atol=} {rtol=} waring!] a and b results are not all close""" + ) logger.info( - f'-->max delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements') - return False + f"-->max delta:{delta.max()}, delta details: {percent:.1%} ({num} of {a.numel()}) elements" + ) + return percent -def tensor_dump(x: torch.tensor, name: str, dir='./'): +def tensor_dump(x: torch.tensor, name: str, dir="./"): x_cpu = x.cpu().view(torch.uint8) - filename = f'{dir}/{name}.bin' + filename = f"{dir}/{name}.bin" x_cpu.numpy().tofile(filename) - logger.info(f'saving {filename} {x.shape}, {x.dtype}') + logger.info(f"saving {filename} {x.shape}, {x.dtype}") - with open(f'{dir}/{name}.meta', 'w') as f: - f.writelines([f'{el}\n' for el in [x.shape, x.dtype]]) + with open(f"{dir}/{name}.meta", "w") as f: + f.writelines([f"{el}\n" for el in [x.shape, x.dtype]]) def tensor_load(filename: str): DWs = np.fromfile(filename, dtype=np.uint32) - metafile = '.'.join(filename.split('.')[:-1])+'.meta' + metafile = ".".join(filename.split(".")[:-1]) + ".meta" shape, dtype = [eval(line.strip()) for line in open(metafile)] return torch.tensor(DWs).view(dtype).view(shape) diff --git a/aiter/utility/mp_tuner.py b/aiter/utility/mp_tuner.py new file mode 100644 index 00000000..e4803bdf --- /dev/null +++ b/aiter/utility/mp_tuner.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
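(The body of the new aiter/utility/mp_tuner.py continues below.) The file fans tuning tasks out across every visible GPU: it pre-warms a spawn-mode process pool, asks each worker for its PID once, and uses the resulting PID-to-index map so that each worker pins a distinct device before timing a kernel with run_perftest. A minimal sketch of that mapping pattern, illustrative only: report_pid and demo_task are hypothetical names, and the sleep-based PID capture mirrors the heuristic the real file uses.

```python
import multiprocessing as mp
import time

import torch


def report_pid():
    time.sleep(1)  # keep this worker busy so every pool worker answers one probe
    return mp.current_process().pid


def demo_task(gpu_map, x):
    gpu_id = gpu_map[mp.current_process().pid]
    torch.cuda.set_device(gpu_id)  # pin this worker process to its own GPU
    return gpu_id, x * 2


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    n = torch.cuda.device_count()
    pool = mp.Pool(processes=n)
    probes = [pool.apply_async(report_pid) for _ in range(n)]
    gpu_map = {p.get(): i for i, p in enumerate(probes)}
    results = [pool.apply_async(demo_task, (gpu_map, i)) for i in range(n)]
    pool.close()
    pool.join()
    print([r.get() for r in results])
```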
+import torch +import multiprocessing as mp +import os +import pandas as pd +import aiter +import time +from aiter.test_common import run_perftest + + +def worker(gpuIDMap, tag, func, args, **kwargs): + pid = mp.current_process().pid + gpuID = gpuIDMap[pid] + args = [el.to("cpu") if isinstance(el, torch.Tensor) else el for el in args] + torch.cuda.synchronize() + + device = torch.device(f"cuda:{gpuID}") + torch.cuda.set_device(device) + args = [el.to(device) if isinstance(el, torch.Tensor) else el for el in args] + torch.cuda.synchronize() + + _, us = run_perftest(func, *args, **kwargs) + torch.cuda.synchronize() + + return tag, us, _.to("cpu") + + +def get_pid(): + time.sleep(1) + return mp.current_process().pid + + +def mp_tuner(tasks): + gpu_num = torch.cuda.device_count() + mp.set_start_method("spawn", force=True) + pool = mp.Pool(processes=gpu_num) + pids = [pool.apply_async(get_pid) for i in range(gpu_num)] + time.sleep(2) + + gpu_map = {el.get(): i for i, el in enumerate(pids)} + rets = [ + pool.apply_async(worker, args=(gpu_map, *task)) for i, task in enumerate(tasks) + ] + + pool.close() + pool.join() + return [el.get() for el in rets] diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index eff796af..afff1be6 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -52,7 +52,7 @@ class AiterAsmKernel AiterAsmKernel(const char *name, const char *hsaco) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); - std::cout << "hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; + std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << std::endl; diff --git a/csrc/include/moe_op.h b/csrc/include/moe_op.h index 82a9afab..7a071dfa 100644 --- a/csrc/include/moe_op.h +++ b/csrc/include/moe_op.h @@ -128,4 +128,18 @@ void fmoe_fp8_blockscale_g1u1(torch::Tensor &out, // [ std::optional fc2_smooth_scale // [expert, 1, inter_dim] ); +void moe_stage1_fp8_g1u1(torch::Tensor &input, // [token_cnt, model_dim] M,K + torch::Tensor &w1, // [expert, inter_dim*2, model_dim] N,K + torch::Tensor &w2, // [expert, model_dim, inter_dim] + torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] + torch::Tensor &sorted_weight_buf, // [max_num_tokens_padded] + torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] + torch::Tensor &num_valid_ids, // [1] + torch::Tensor &out, // [token_cnt, topk, inter_dim] + std::string &kernelName, + int block_size, + std::optional a1_scale = std::nullopt, // [token_cnt, 1], token scale + std::optional w1_scale = std::nullopt // [expert, 1, inter_dim], gate(up) scale +); + void moe_sum(torch::Tensor &input, torch::Tensor &output); diff --git a/csrc/include/quant.h b/csrc/include/quant.h index 46697011..d152da5b 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -4,6 +4,14 @@ #include +enum class QuantType : int +{ + No, + per_Tensor, + per_Token, + per_128x128, +}; + void static_scaled_fp8_quant(torch::Tensor &out, // [..., d] torch::Tensor const &input, // [..., d] torch::Tensor const &scale); // [1] diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index b7eb27cd..9be4a0ee 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -341,7 +341,7 @@ "Aligning the number of tokens to be processed by each 
expert such " \ "that it is divisible by the block size."); \ m.def("fmoe", &fmoe); \ - py::enum_(m, "ActivationType") \ + py::enum_(m, "ActivationType") \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ .export_values(); \ @@ -375,6 +375,17 @@ py::arg("input_scale"), \ py::arg("fc_scale_blkn") = 128, py::arg("fc_scale_blkk") = 128, \ py::arg("fc2_smooth_scale") = std::nullopt); \ + m.def("moe_stage1_fp8_g1u1", &moe_stage1_fp8_g1u1, \ + py::arg("input"), \ + py::arg("w1"), py::arg("w2"), \ + py::arg("sorted_token_ids"), py::arg("sorted_weight_buf"), \ + py::arg("sorted_expert_ids"), py::arg("num_valid_ids"), \ + py::arg("out"), \ + py::arg("kernelName"), \ + py::arg("block_size"), \ + py::arg("a1_scale") = std::nullopt, \ + py::arg("w1_scale") = std::nullopt); \ + \ m.def("moe_sum", &moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); #define MOE_SORTING_PYBIND \ @@ -422,6 +433,12 @@ m.def("batched_rotary_embedding", &batched_rotary_embedding, "batched_rotary_embedding"); #define QUANT_PYBIND \ + py::enum_(m, "QuantType") \ + .value("No", QuantType::No) \ + .value("per_Tensor", QuantType::per_Tensor) \ + .value("per_Token", QuantType::per_Token) \ + .value("per_128x128", QuantType::per_128x128) \ + .export_values(); \ m.def("static_scaled_fp8_quant", &static_scaled_fp8_quant); \ m.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant); \ m.def("dynamic_per_token_scaled_fp8_quant", &dynamic_per_token_scaled_fp8_quant, \ diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm.hpp b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm.hpp index 6081c12f..c05d3cc5 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm.hpp +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm.hpp @@ -25,6 +25,7 @@ using F16 = ck::half_t; using B16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; +using I4 = ck::pk_i4_t; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -79,20 +80,63 @@ struct MulABScale template <> __host__ __device__ constexpr void operator()(F16 &e, - const int &c, - const float &d0, - const float &d1) const + const int &c, + const float &d0, + const float &d1) const { e = ck::type_convert(ck::type_convert(c) * d1 * d0); } template <> __host__ __device__ constexpr void operator()(B16 &e, - const int &c, + const int &c, + const float &d0, + const float &d1) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d0); + } +}; + +struct MulABScaleWint4 +{ + template + __host__ __device__ constexpr void + operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const float &c, const float &d0, const float &d1) const { - e = ck::type_convert(ck::type_convert(c) * d1 * d0); + e = ck::type_convert(c * d1 * d0 * 16.f); + } + + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const float &c, + const float &d0, + const float &d1) const + { + e = ck::type_convert(c * d1 * d0 * 16.f); + } + + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const int &c, + const float &d0, + const float &d1) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d0 * 16.f); + } + + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const int &c, + const float &d0, + const float &d1) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d0 * 16.f); } }; @@ -122,26 +166,26 @@ struct TypeCastExpertWeight template <> __host__ __device__ constexpr void 
operator()(F16 &e, - const int &c, - const float &d0, - const float &d1, - const float &d2) const + const int &c, + const float &d0, + const float &d1, + const float &d2) const { e = ck::type_convert(ck::type_convert(c) * d2); } template <> __host__ __device__ constexpr void operator()(B16 &e, - const int &c, - const float &d0, - const float &d1, - const float &d2) const + const int &c, + const float &d0, + const float &d1, + const float &d2) const { e = ck::type_convert(ck::type_convert(c) * d2); } }; // d0: ascale, d1: bscale, d2:expert weight -//warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix +// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix struct MulABScaleExpertWeight { template @@ -166,27 +210,73 @@ struct MulABScaleExpertWeight e = ck::type_convert(c * d1 * d2); } - template <> - __host__ __device__ constexpr void operator()(F16 &e, - const int &c, - const float &d0, - const float &d1, - const float &d2) const - { - e = ck::type_convert(ck::type_convert(c) * d1 * d2); - } - template <> - __host__ __device__ constexpr void operator()(B16 &e, - const int &c, - const float &d0, - const float &d1, - const float &d2) const - { - e = ck::type_convert(ck::type_convert(c) * d1 * d2); - } + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const int &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d2); + } + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const int &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d2); + } +}; + +// d0: ascale, d1: bscale, d2:expert weight +// warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. 
tofix:felix +struct MulABScaleExpertWeightWin4 +{ + template + __host__ __device__ constexpr void + operator()(E &e, const C &c, const D0 &d0, const D1 &d1, const D2 &d2) const; + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const float &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(c * d1 * d2 * 16.f); + } + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const float &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(c * d1 * d2 * 16.f); + } + + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const int &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d2 * 16.f); + } + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const int &c, + const float &d0, + const float &d1, + const float &d2) const + { + e = ck::type_convert(ck::type_convert(c) * d1 * d2 * 16.f); + } }; -template +template void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&hidden_states, // [m, k], input token @@ -200,7 +290,7 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, std::optional a1_scale = std::nullopt // [m, 1], token scale ); -template +template void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&inter_states, // [max_num_tokens_padded, k], input token diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16_f8.cu deleted file mode 100644 index a796f600..00000000 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16_f8.cu +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "moe_ck_gemm_common.cuh" - -using A0DataType = F8; -using B0DataType = F8; -using AccDataType = F32; -using EDataType = B16; -using CDEElementOp = MulABScale; - -CK_MOE_STAGE1_GEMM_DEFINE(32, 256, 1, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(64, 256, 2, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(128, 128, 2, 2, false) -// CK_MOE_STAGE1_GEMM_DEFINE(32, 256, 1, 1, true) -// CK_MOE_STAGE1_GEMM_DEFINE(64, 256, 2, 1, true) -// CK_MOE_STAGE1_GEMM_DEFINE(128, 128, 2, 2, true) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16_f8.cu deleted file mode 100644 index af64f7d1..00000000 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16_f8.cu +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
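(The deletion of moe_ck_gemm1_instance_f16_f8.cu continues below.) The stage-1 FP8 instance files removed here carried hand-picked tile literals such as (32, 256, 1, 1); the renamed pertensor_*/pertoken_* files further down instead pass 256/sizeof(A0DataType) or 128/sizeof(A0DataType), which appears to keep the per-block K footprint fixed in bytes while the element count adapts to the input dtype. The old literals agree with that reading: 256 elements of 1-byte F8 and 128 elements of 2-byte B16 are both 256 bytes. A quick sanity check of the arithmetic (not part of the patch):

```python
# bytes per element for the CK typedefs used by these instances
sizeof = {"F8": 1, "F16": 2, "B16": 2}
for a0 in ("F8", "F16", "B16"):
    # mirrors CK_MOE_STAGE1_GEMM_DEFINE(x, 256/sizeof(A0DataType), ...) and the 128-byte variant
    print(a0, 256 // sizeof[a0], 128 // sizeof[a0])
# F8 -> 256 / 128 elements per K tile; F16/B16 -> 128 / 64
```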
-#include "moe_ck_gemm_common.cuh" - -using A0DataType = F8; -using B0DataType = F8; -using AccDataType = F32; -using EDataType = F16; -using CDEElementOp = MulABScale; - -CK_MOE_STAGE1_GEMM_DEFINE(32, 256, 1, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(64, 256, 2, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(128, 128, 2, 2, false) -// CK_MOE_STAGE1_GEMM_DEFINE(32, 256, 1, 1, true) -// CK_MOE_STAGE1_GEMM_DEFINE(64, 256, 2, 1, true) -// CK_MOE_STAGE1_GEMM_DEFINE(128, 128, 2, 2, true) \ No newline at end of file diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16.cu similarity index 52% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16.cu rename to csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16.cu index fafd0f15..5aae2719 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_b16.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16.cu @@ -7,7 +7,8 @@ using B0DataType = B16; using AccDataType = F32; using EDataType = B16; using CDEElementOp = TypeCast; - -CK_MOE_STAGE1_GEMM_DEFINE(32, 128, 1, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(64, 128, 2, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(128, 64, 2, 2, false) +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8.cu new file mode 100644 index 00000000..f5d25e89 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScale; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8_wint4.cu new file mode 100644 index 00000000..c97d5e3f --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_b16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScaleWint4; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16.cu similarity index 52% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16.cu rename to csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16.cu index 0edd0337..cf834c70 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_f16.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16.cu @@ -7,7 +7,9 @@ using B0DataType = F16; using AccDataType = F32; using EDataType = F16; using CDEElementOp = TypeCast; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) -CK_MOE_STAGE1_GEMM_DEFINE(32, 128, 1, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(64, 128, 2, 1, false) -CK_MOE_STAGE1_GEMM_DEFINE(128, 64, 2, 2, false) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8.cu new file mode 100644 index 00000000..218b2a5e --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScale; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8_win4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8_win4.cu new file mode 100644 index 00000000..8be8a8d6 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertensor_f16_f8_win4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScaleWint4; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE1_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16.cu new file mode 100644 index 00000000..fa8e17bb --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = B16; +using B0DataType = B16; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = TypeCast; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8.cu new file mode 100644 index 00000000..3ec61461 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScale; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8_wint4.cu new file mode 100644 index 00000000..06732f65 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_b16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScaleWint4; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16.cu new file mode 100644 index 00000000..d865efcd --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+#include "moe_ck_gemm_common.cuh" + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = TypeCast; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8.cu new file mode 100644 index 00000000..103fa17a --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScale; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8_win4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8_win4.cu new file mode 100644 index 00000000..8c5c6d9d --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm1_instance_pertoken_f16_f8_win4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScaleWint4; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE1_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE1_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16..cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16..cu similarity index 53% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16..cu rename to csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16..cu index 716abf0c..14a43b86 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16..cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16..cu @@ -7,7 +7,9 @@ using B0DataType = B16; using AccDataType = F32; using EDataType = B16; using CDEElementOp = TypeCastExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) -CK_MOE_STAGE2_GEMM_DEFINE(32) -CK_MOE_STAGE2_GEMM_DEFINE(64) -CK_MOE_STAGE2_GEMM_DEFINE(128) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8.cu similarity index 53% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16_f8.cu rename to 
csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8.cu index 235681f9..6d8705a9 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_b16_f8.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8.cu @@ -7,7 +7,9 @@ using B0DataType = F8; using AccDataType = F32; using EDataType = B16; using CDEElementOp = MulABScaleExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) -CK_MOE_STAGE2_GEMM_DEFINE(32) -CK_MOE_STAGE2_GEMM_DEFINE(64) -CK_MOE_STAGE2_GEMM_DEFINE(128) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8_wint4.cu new file mode 100644 index 00000000..8997d734 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_b16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScaleExpertWeightWin4; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16.cu similarity index 53% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16.cu rename to csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16.cu index e326412b..aa5be991 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16.cu @@ -7,8 +7,10 @@ using B0DataType = F16; using AccDataType = F32; using EDataType = F16; using CDEElementOp = TypeCastExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) -CK_MOE_STAGE2_GEMM_DEFINE(32) -CK_MOE_STAGE2_GEMM_DEFINE(64) -CK_MOE_STAGE2_GEMM_DEFINE(128) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8.cu similarity index 53% rename from csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16_f8.cu rename to csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8.cu index 9a8ea43f..b4f7b1ea 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_f16_f8.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8.cu @@ -7,7 +7,9 @@ using B0DataType = F8; using AccDataType = F32; using EDataType = F16; using CDEElementOp = MulABScaleExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 
256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) -CK_MOE_STAGE2_GEMM_DEFINE(32) -CK_MOE_STAGE2_GEMM_DEFINE(64) -CK_MOE_STAGE2_GEMM_DEFINE(128) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8_wint4.cu new file mode 100644 index 00000000..6ef870d3 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertensor_f16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScaleExpertWeightWin4; +const bool Nswizzle = false; +const bool PerTensorQuant = true; +CK_MOE_STAGE2_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16..cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16..cu new file mode 100644 index 00000000..d63ab951 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16..cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = B16; +using B0DataType = B16; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = TypeCastExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8.cu new file mode 100644 index 00000000..8e82b29a --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScaleExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8_wint4.cu new file mode 100644 index 00000000..ef3592db --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_b16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
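+// NOTE: the pertensor_* / pertoken_* file-name split mirrors the
+// PerTensorQuant flag set below. Per-tensor instances broadcast one scale
+// (the D-tensor stride collapses to I0 and D1Vec stays 1); per-token
+// instances read a scale per row (DStride = I1, D1Vec = EVec), matching the
+// "PerTensorQuant ? I0 : I1" selection in moe_ck_gemm_common.cuh.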
+#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = B16; +using CDEElementOp = MulABScaleExpertWeightWin4; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16.cu new file mode 100644 index 00000000..28d5d3aa --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = TypeCastExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8.cu new file mode 100644 index 00000000..1fbd5c7f --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8.cu @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = F8; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScaleExpertWeight; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 256/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 2, 2) + diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8_wint4.cu b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8_wint4.cu new file mode 100644 index 00000000..d4927be6 --- /dev/null +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm2_instance_pertoken_f16_f8_wint4.cu @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include "moe_ck_gemm_common.cuh" + +using A0DataType = F8; +using B0DataType = I4; +using AccDataType = F32; +using EDataType = F16; +using CDEElementOp = MulABScaleExpertWeightWin4; +const bool Nswizzle = false; +const bool PerTensorQuant = false; +CK_MOE_STAGE2_GEMM_DEFINE(32, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(64, 128/sizeof(A0DataType), 1, 4) +CK_MOE_STAGE2_GEMM_DEFINE(128, 128/sizeof(A0DataType), 1, 4) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm_common.cuh b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm_common.cuh index 5bb443d0..e7ccfdd5 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm_common.cuh +++ b/csrc/py_itfs_ck/moe_ck_2stages_gemm_impl/moe_ck_gemm_common.cuh @@ -2,8 +2,8 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "moe_ck_gemm.hpp" - -template +#include +template void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&hidden_states, // [m, k], input token @@ -23,7 +23,6 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, ck::index_t StrideD = 0; ck::index_t StrideE = N; ck::index_t KBatch = 1; - static_assert(NSwizzle==false, "disabled. need other prs to be ready."); // using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple; @@ -43,13 +42,25 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; + static constexpr ck::index_t BLOCKSIZE = 256; + static constexpr ck::index_t NPerBlock = 128; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; + // static constexpr ck::index_t MWaves = 1; + // static constexpr ck::index_t NWaves = WAVES / MWaves; + static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); + static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 1 : MXDLPerWave; + // static constexpr ck::index_t KPerBlock = ck::is_same_v ? 128 : 256 / sizeof(A0DataType); static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); + static constexpr ck::index_t BK1 = ck::is_same_v ? 32 : 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M_A = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N_B = BLOCKSIZE / K0_B; static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 
1 : EVec; - // using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -63,23 +74,23 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, //threadnum, mblock, nblock, kblock - 256, MPerBlock, 128, KPerBlock, + 256, MPerBlock, NPerBlock, KPerBlock, // ak1, bk1 AK1, BK1, // mn_perxdl MNPerXDL, MNPerXDL, // mn_xdlperwave - MWAVE, NWAVE, + MXDLPerWave, NXDLPerWave, // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - MWAVE, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, NSwizzle, true, A0DataType>; + CShuffleMXDLPerWave, 1, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -92,6 +103,8 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, constexpr ck::index_t NumDTensor = DsDataType::Size(); constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + static constexpr auto DStride = PerTensorQuant ? 
I0 : I1; // do GEMM auto device_op = DeviceOpInstance{}; @@ -113,7 +126,7 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, K, StrideA, StrideB, - std::array{I0, I0}, + std::array{DStride, DStride}, StrideE, KBatch, a_element_op, @@ -130,22 +143,22 @@ void ck_moe_stage1_gemm(const hipStream_t &stream, int tokens, int sorted_size, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE1_GEMM_DEFINE(MPerBlock, KPerBlock, MWaves, NWaves, NSwizzle) \ - template void ck_moe_stage1_gemm( \ - const hipStream_t &stream, \ - int tokens, int sorted_size, int N, int K, \ - int topk, \ - void *&hidden_states, \ - void *&w1, \ - void *&w2, \ - void *&sorted_token_ids, \ - void *&sorted_expert_ids, \ - void *&num_valid_ids, \ - void *&out, \ - std::optional w1_scale, \ +#define CK_MOE_STAGE1_GEMM_DEFINE(MPerfBlock, KPerBlock, MWaves, NWaves) \ + template void ck_moe_stage1_gemm( \ + const hipStream_t &stream, \ + int tokens, int sorted_size, int N, int K, \ + int topk, \ + void *&hidden_states, \ + void *&w1, \ + void *&w2, \ + void *&sorted_token_ids, \ + void *&sorted_expert_ids, \ + void *&num_valid_ids, \ + void *&out, \ + std::optional w1_scale, \ std::optional a1_scale); -template +template void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, int N, int K, int topk, void *&inter_states, // [max_num_tokens_padded, k], input token @@ -182,58 +195,61 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, using AElementOp = PassThrough; using BElementOp = PassThrough; // using CDEElementOp = MultiplyMultiply; - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - // static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; + static constexpr ck::index_t WAVES = BLOCKSIZE / 64; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; - static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); - static constexpr ck::index_t MXDLPerWave = MPerBlock <= 64 ? MPerBlock / 32 : MPerBlock / 64; - static constexpr ck::index_t NXDLPerWave = MPerBlock <= 64 ? 1 : 2; - static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; - static constexpr ck::index_t CShuffleNLane = NPerBlock / 2 / NXDLPerWave; + // static constexpr ck::index_t MWaves = 1; + // static constexpr ck::index_t NWaves = WAVES / MWaves; + static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * MWaves); + static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * NWaves); + // static constexpr ck::index_t KPerBlock = ck::is_same_v ? 128 : 256 / sizeof(A0DataType); + static constexpr ck::index_t CShuffleMXDLPerWave = ck::is_same_v ? 1 : MXDLPerWave; + static constexpr ck::index_t CShuffleNLane = ck::is_same_v ? 32 : NPerBlock / 2 / NXDLPerWave; // 64 static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); - static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); + static constexpr ck::index_t BK1 = ck::is_same_v ? 32 / sizeof(B0DataType) : 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; - static constexpr ck::index_t D1Vec = 1; + static constexpr ck::index_t D1Vec = PerTensorQuant ? 
1 : EVec; static constexpr ck::index_t D2Vec = 1; + static constexpr ck::index_t K0_A = KPerBlock / AK1; + static constexpr ck::index_t K0_B = KPerBlock / BK1; + static constexpr ck::index_t K0_M = BLOCKSIZE / K0_A; + static constexpr ck::index_t K0_N = BLOCKSIZE / K0_B; + using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RCR - // kernel 1: 256->32x128x128 - // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; - // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - //threadnum, mblock, nblock, kblock - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - // ak1, bk1 - AK1, BK1, - // mn_perxdl - MNPerXDL, MNPerXDL, - // mn_xdlperwave - MXDLPerWave, NXDLPerWave, - // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra - // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, 
A0DataType>; - // kernel 2: 128->32x128x128 - // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; +///#####| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///#####| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///#####| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///#####| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///##### RCR + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, false, A0DataType>; + // kernel 
2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; - // clang-format on auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -242,6 +258,8 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, constexpr ck::index_t NumDTensor = DsDataType::Size(); constexpr auto I0 = ck::Number<0>{}; + constexpr auto I1 = ck::Number<1>{}; + static constexpr auto DStride = PerTensorQuant ? I0 : I1; // do GEMM auto device_op = DeviceOpInstance{}; @@ -264,7 +282,7 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, K, StrideA, StrideB, - std::array{I0, I0, I0}, + std::array{DStride, DStride, I0}, StrideE, KBatch, a_element_op, @@ -280,18 +298,18 @@ void ck_moe_stage2_gemm(const hipStream_t &stream, int tokens, int sorted_size, invoker.Run(argument, StreamConfig{stream}); } -#define CK_MOE_STAGE2_GEMM_DEFINE(MPerfBlock) \ - template void ck_moe_stage2_gemm( \ - const hipStream_t &stream, \ - int tokens, int sorted_size, int N, int K, \ - int topk, \ - void *&inter_states, \ - void *&w1, \ - void *&w2, \ - void *&sorted_token_ids, \ - void *&sorted_expert_ids, \ - void *&sorted_weights, \ - void *&num_valid_ids, \ - void *&out, \ - std::optional w2_scale, \ +#define CK_MOE_STAGE2_GEMM_DEFINE(MPerfBlock, KPerBlock, MWaves, NWaves) \ + template void ck_moe_stage2_gemm( \ + const hipStream_t &stream, \ + int tokens, int sorted_size, int N, int K, \ + int topk, \ + void *&inter_states, \ + void *&w1, \ + void *&w2, \ + void *&sorted_token_ids, \ + void *&sorted_expert_ids, \ + void *&sorted_weights, \ + void *&num_valid_ids, \ + void *&out, \ + std::optional w2_scale, \ std::optional a2_scale); \ No newline at end of file diff --git a/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu b/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu index 7774e9d5..7b0ba996 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu @@ -4,17 +4,47 @@ #include #include #include "py_itfs_common.h" - #include "moe_ck_gemm.hpp" -#define CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock, NSwizzle) \ - if (MPerBlock == 32) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ - else if (MPerBlock == 64) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ - else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); +#define CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock) \ + if (isPerTensorQuant) \ + { \ + if (MPerBlock == 32) \ + 
ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + } \ + else \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + } +#define CK_MOE_STAGE1_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock) \ + if (isPerTensorQuant) \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + } \ + else \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + } void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) @@ -30,8 +60,8 @@ void 
ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token { const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(hidden_states.dtype() == w1.dtype(), - "Weights and activations should both be same dtype!"); + // TORCH_CHECK(hidden_states.dtype() == w1.dtype(), + // "Weights and activations should both be same dtype!"); TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!") @@ -40,20 +70,22 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token int sorted_size = sorted_token_ids.size(0); int E = w1.size(0); int N = w1.size(1); - int K = w1.size(2); + int K = hidden_states.size(-1); // int max_num_tokens_padded = sorted_token_ids.size(0); // int agvtokens_per_expert = max_num_tokens_padded / E; int MPerBlock = block_m.value(); + bool isPerTensorQuant = (!w1_scale.has_value()) || (w1_scale.value().numel() == E); + // int M = agvtokens_per_expert < 32 ? 32 : (agvtokens_per_expert < 64 ? 64 : 128); void *hidden_states_ptr = hidden_states.data_ptr(); - void *w1_ptr = w1.data_ptr(); + void *w1_ptr = w1.transpose(1, 2).data_ptr(); void *w2_ptr = w2.data_ptr(); void *sorted_token_ids_ptr = sorted_token_ids.data_ptr(); void *sorted_expert_ids_ptr = sorted_expert_ids.data_ptr(); void *num_valid_ids_ptr = num_valid_ids.data_ptr(); void *out_ptr = out.data_ptr(); - void *w1_scale_ptr = w1_scale.has_value() ? w1_scale.value().data_ptr() : nullptr; + void *w1_scale_ptr = w1_scale.has_value() ? w1_scale.value().transpose(0, 1).data_ptr() : nullptr; void *a1_scale_ptr = a1_scale.has_value() ? a1_scale.value().data_ptr() : nullptr; // BF16 @@ -64,7 +96,8 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token using AccDataType = F32; using EDataType = B16; using CDEElementOp = TypeCast; - CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock, false); + const bool Nswizzle = false; + CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } // FP16 else if (hidden_states.dtype() == at::ScalarType::Half) @@ -74,7 +107,29 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token using AccDataType = F32; using EDataType = F16; using CDEElementOp = TypeCast; - CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock, false); + const bool Nswizzle = false; + CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } + // FP8 Wint4 + else if (hidden_states.dtype() == at::ScalarType::Float8_e4m3fnuz && w1.dtype() == at::ScalarType::UInt32) + { + using A0DataType = F8; + using B0DataType = I4; + const bool Nswizzle = false; + TORCH_CHECK(a1_scale.has_value() && w1_scale.has_value(), + "MoE Quant must input scale!"); + TORCH_CHECK(a1_scale.value().dtype() == at::ScalarType::Float, + "Scales must be Float dtype!"); + using AccDataType = F32; + using CDEElementOp = MulABScaleWint4; + if (out.dtype() == at::ScalarType::Half) + { + CK_MOE_STAGE1_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } + else if (out.dtype() == at::ScalarType::BFloat16) + { + CK_MOE_STAGE1_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } } // FP8 else if (hidden_states.dtype() == 
at::ScalarType::Float8_e4m3fnuz) @@ -87,27 +142,14 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token "Scales must be Float dtype!"); using AccDataType = F32; using CDEElementOp = MulABScale; + const bool Nswizzle = false; if (out.dtype() == at::ScalarType::Half) { - // if (N % 8 == 0 && N > 8192) - // { - // CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, MPerBlock, true); - // } - // else - { - CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, MPerBlock, false); - } + CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } else if (out.dtype() == at::ScalarType::BFloat16) { - // if (N % 8 == 0 && N > 8192) - // { - // CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, MPerBlock, true); - // } - // else - { - CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, MPerBlock, false); - } + CK_MOE_STAGE1_GEMM_IMPL(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } } // // I8 @@ -132,13 +174,45 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token // } } -#define CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock) \ - if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ - else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ - else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); +#define CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock) \ + if (isPerTensorQuant) \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + } \ + else \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 64) \ + 
ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + } + +#define CK_MOE_STAGE2_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock) \ + if (isPerTensorQuant) \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + } \ + else \ + { \ + if (MPerBlock == 32) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 64) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + else if (MPerBlock == 128) \ + ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + } void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token torch::Tensor &w1, // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) @@ -153,21 +227,22 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token std::optional a2_scale = std::nullopt, // [m, 1], token scale std::optional block_m = 32) { - TORCH_CHECK(inter_states.dtype() == w2.dtype(), - "Weights and activations should both be same dtype!"); - + // TORCH_CHECK(inter_states.dtype() == w2.dtype(), + // "Weights and activations should both be same dtype!"); + // TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!") int tokens = inter_states.size(0); int sorted_size = sorted_token_ids.size(0); int E = w1.size(0); - int N = w1.size(2); - int K = w2.size(2); + int N = w2.size(1); + int K = inter_states.size(-1); // int max_num_tokens_padded = sorted_token_ids.size(0); // int agvtokens_per_expert = max_num_tokens_padded / E; int MPerBlock = block_m.value(); // int M = agvtokens_per_expert < 32 ? 32 : (agvtokens_per_expert < 64 ? 
64 : 128); + bool isPerTensorQuant = (!w2_scale.has_value()) || (w2_scale.value().numel() == E); void *inter_states_ptr = inter_states.data_ptr(); void *w1_ptr = w1.data_ptr(); @@ -188,7 +263,8 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token using AccDataType = F32; using EDataType = B16; using CDEElementOp = TypeCastExpertWeight; - CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock); + const bool Nswizzle = false; + CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } // FP16 else if (inter_states.dtype() == at::ScalarType::Half) @@ -198,7 +274,29 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token using AccDataType = F32; using EDataType = F16; using CDEElementOp = TypeCastExpertWeight; - CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, MPerBlock); + const bool Nswizzle = false; + CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, EDataType, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } + // FP8 wint4 + else if (inter_states.dtype() == at::ScalarType::Float8_e4m3fnuz && w1.dtype() == at::ScalarType::UInt32) + { + using A0DataType = F8; + using B0DataType = I4; + const bool Nswizzle = false; + TORCH_CHECK(a2_scale.has_value() && w2_scale.has_value(), + "MoE Quant must input scale!"); + TORCH_CHECK(a2_scale.value().dtype() == at::ScalarType::Float, + "Scales must be Float dtype!"); + using AccDataType = F32; + using CDEElementOp = MulABScaleExpertWeightWin4; + if (out.dtype() == at::ScalarType::Half) + { + CK_MOE_STAGE2_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } + else if (out.dtype() == at::ScalarType::BFloat16) + { + CK_MOE_STAGE2_GEMM_IMPL_INT4(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); + } } // FP8 else if (inter_states.dtype() == at::ScalarType::Float8_e4m3fnuz) @@ -211,13 +309,14 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token "Scales must be Float dtype!"); using AccDataType = F32; using CDEElementOp = MulABScaleExpertWeight; + const bool Nswizzle = false; if (out.dtype() == at::ScalarType::Half) { - CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, MPerBlock); + CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, F16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } else if (out.dtype() == at::ScalarType::BFloat16) { - CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, MPerBlock); + CK_MOE_STAGE2_GEMM_IMPL(A0DataType, B0DataType, AccDataType, B16, CDEElementOp, Nswizzle, isPerTensorQuant, MPerBlock); } } // // I8 diff --git a/csrc/py_itfs_cu/asm_fmoe.cpp b/csrc/py_itfs_cu/asm_fmoe.cpp index fbd1ba0d..45b1b0ef 100644 --- a/csrc/py_itfs_cu/asm_fmoe.cpp +++ b/csrc/py_itfs_cu/asm_fmoe.cpp @@ -76,7 +76,7 @@ class FMoeKernel FMoeKernel(const char *name, const char *hsaco, uint32_t sub_GU = 512) { const char *AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); - std::cout << "hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; + std::cout << "[aiter] hipModuleLoad: " << (std::string(AITER_ASM_DIR) + hsaco).c_str() << " GetFunction: " << name; HIP_CALL(hipModuleLoad(&module, (std::string(AITER_ASM_DIR) + hsaco).c_str())); HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); std::cout << " Success" << 
std::endl; diff --git a/csrc/py_itfs_cu/asm_moe_2stage.cpp b/csrc/py_itfs_cu/asm_moe_2stage.cpp new file mode 100644 index 00000000..d79268af --- /dev/null +++ b/csrc/py_itfs_cu/asm_moe_2stage.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include +#include +#include "aiter_hip_common.h" + +struct __attribute__((packed)) KernelArgs +{ + void *ptr_O; + p2 _p0; + void *ptr_X; + p2 _p1; + void *ptr_GU; + p2 _p2; + void *ptr_XC; + p2 _p3; + void *ptr_XQ; + p2 _p4; + void *ptr_GUQ; + p2 _p5; + void *ptr_SMQ; + p2 _p6; + void *ptr_STP; + p2 _p7; + void *ptr_SEP; + p2 _p8; + unsigned int dim; + p3 _p9; + unsigned int hidden_dim; + p3 _p10; + unsigned int token_cnt; + p3 _p11; + unsigned int eprt_cnt; + p3 _p12; + unsigned int Xs; + p3 _p13; + unsigned int GUs; + p3 _p14; + unsigned int Os; + p3 _p15; + unsigned int eGUs; + p3 _p16; + unsigned int eGUQs; + p3 _p17; + unsigned int eSMQs; + p3 _p18; + unsigned int topk; + p3 _p19; +}; + +struct FMoe2StageConfig +{ + std::string name; + std::string co_name; + int tile_M; + int tile_N; +}; + +#define ADD_CFG(M, N, path, name) \ + { \ + name, { name, path name ".co", M, N } \ + } +using CFG = std::unordered_map; + +CFG *get_cfg(torch::Tensor &inp, torch::Tensor &out, torch::Tensor &w1) +{ + if (inp.scalar_type() == at::ScalarType::Float8_e4m3fnuz && + w1.scalar_type() == at::ScalarType::Float8_e4m3fnuz && + out.scalar_type() == at::ScalarType::BFloat16) + { + static CFG cfg_fmoe_stage1_bf16_pertokenFp8_g1u1 = { + ADD_CFG(16, 64, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_16x64"), + ADD_CFG(16, 64, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_16x64_pf2"), + + ADD_CFG(16, 512, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2"), + + ADD_CFG(32, 64, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x64"), + + ADD_CFG(32, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128"), + ADD_CFG(32, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf2"), + ADD_CFG(32, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3"), + ADD_CFG(32, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2"), + + ADD_CFG(32, 512, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2"), + + ADD_CFG(48, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_48x128"), + + ADD_CFG(128, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_128x128"), + ADD_CFG(128, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2"), + + ADD_CFG(160, 128, "fmoe_2stages/", "fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2"), + }; + return &cfg_fmoe_stage1_bf16_pertokenFp8_g1u1; + } + else + { + TORCH_CHECK(false, "Unsupported input_type:", inp.scalar_type(), ", out_type:", out.scalar_type()); + } +}; +std::string get_heuristic_kernel(int m_num, int N, int blockk_size, CFG *cfgs) +{ + hipDevice_t dev; + hipDeviceProp_t dev_prop; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + uint32_t num_cu = dev_prop.multiProcessorCount; + uint32_t empty_cu = num_cu; + uint32_t tg_num = 0; + uint32_t round = 0xffffffff; + std::string selected; + + for (const auto &el : *cfgs) + { + const auto &cfg = el.second; + if (cfg.tile_M != blockk_size) + { + continue; + } + + tg_num = (N + cfg.tile_N - 1) / cfg.tile_N * m_num; + uint32_t local_round = (tg_num + num_cu - 1) / num_cu; + if (local_round < round) + { + round = local_round; + selected = 
el.first; + empty_cu = local_round * num_cu - tg_num; + } + else if (local_round == round) + { + if (empty_cu > (local_round * num_cu - tg_num)) + { + round = local_round; + selected = el.first; + empty_cu = local_round * num_cu - tg_num; + } + } + } + return selected; +} +void moe_stage1_fp8_g1u1( + torch::Tensor &input, // [token_cnt, model_dim] M,K + torch::Tensor &w1, // [expert, inter_dim*2, model_dim] N,K + torch::Tensor &w2, // [expert, model_dim, inter_dim] + torch::Tensor &sorted_token_ids, // [max_num_tokens_padded] + torch::Tensor &sorted_weight_buf, // [max_num_tokens_padded] + torch::Tensor &sorted_expert_ids, // [max_num_m_blocks] + torch::Tensor &num_valid_ids, // [1] + torch::Tensor &out, // [token_cnt, topk, inter_dim] + std::string &kernelName, + int block_size, + std::optional a1_scale = std::nullopt, // [token_cnt, 1], token scale + std::optional w1_scale = std::nullopt // [expert, 1, inter_dim], gate(up) scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + CFG *config_map = get_cfg(input, out, w1); + static std::unordered_map> impl_ptr_map; + int hidden_dim = out.size(2); + int sub_X_cnt = sorted_expert_ids.size(0); + if (kernelName.empty()) + { + kernelName = get_heuristic_kernel(sub_X_cnt, hidden_dim, block_size, config_map); + } + + AiterAsmKernel *impl_ptr = nullptr; + auto it = config_map->find(kernelName); + if (it != config_map->end()) + { + const auto &cfg = it->second; + const char *name = cfg.name.c_str(); + const char *co_name = cfg.co_name.c_str(); + + auto result = impl_ptr_map.emplace(name, nullptr); + if (result.second) + { + result.first->second = std::make_unique(name, co_name); + } + impl_ptr = result.first->second.get(); + } + else + TORCH_CHECK(false, __func__, " Unsupported " + kernelName); + + int token_cnt = out.size(0); + int topk = out.size(1); + + // const char *enable_vskip = std::getenv("AITER_ENABLE_VSKIP"); + + int dim = w2.size(1); + int eprt = w1.size(0); + const auto &cfg = it->second; + uint32_t sub_GU = cfg.tile_N; + TORCH_CHECK(block_size == cfg.tile_M, __func__, "need make sure block_size == cfg.tile_M"); + + int stride_X = input.stride(0) * input.element_size(); + int stride_GU = dim * w1.element_size(); + + int stride_expert_GU = stride_GU * hidden_dim; + int stride_expert_GUDQN = w1_scale.has_value() ? w1_scale.value().stride(0) * sizeof(float) : 0; + int stride_expert_SMTDQN = hidden_dim * sizeof(float); + int stride_O = hidden_dim * out.element_size() * topk; + if (hidden_dim * 2 == w1.size(1)) + { + stride_expert_GU *= 2; + } + + KernelArgs args; + size_t arg_size = sizeof(args); + args.ptr_O = out.data_ptr(); + args.ptr_X = input.data_ptr(); + args.ptr_GU = w1.data_ptr(); + args.ptr_XC = num_valid_ids.data_ptr(); + + args.ptr_XQ = a1_scale.has_value() ? a1_scale.value().data_ptr() : nullptr; + args.ptr_GUQ = w1_scale.has_value() ? w1_scale.value().data_ptr() : nullptr; + // args.ptr_SMQ = w2_smooth_qnt.has_value() ? 
w2_smooth_qnt.value().data_ptr() : nullptr; + + args.ptr_STP = sorted_token_ids.data_ptr(); + args.ptr_SEP = sorted_expert_ids.data_ptr(); + args.dim = dim; + args.hidden_dim = hidden_dim; + args.token_cnt = token_cnt; + args.eprt_cnt = eprt; + args.Xs = stride_X; + args.GUs = stride_GU; + args.Os = stride_O; + args.eGUs = stride_expert_GU; + args.eGUQs = stride_expert_GUDQN; + args.eSMQs = stride_expert_SMTDQN; + args.topk = topk; + + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &args, HIP_LAUNCH_PARAM_BUFFER_SIZE, + &arg_size, HIP_LAUNCH_PARAM_END}; + + int bdx = 256; + int gdx = ((hidden_dim + sub_GU - 1) / sub_GU); + int gdy = sub_X_cnt; + int gdz = 1; + + impl_ptr->launch_kernel({&args, + &arg_size, + gdx, // gdx + gdy, // gdy + gdz, // gdz + bdx, // bdx: 4 wv64 + 1, // bdy + 1, // bdz + stream}); +} diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128.co new file mode 100755 index 00000000..ef2f3baf Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2.co new file mode 100755 index 00000000..3f9b25e1 Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2.co new file mode 100755 index 00000000..e8a69269 Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2.co new file mode 100755 index 00000000..3d9a426a Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64.co new file mode 100755 index 00000000..e640b5c8 Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64_pf2.co new file mode 100755 index 00000000..2cb426e9 Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_16x64_pf2.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128.co new file mode 100755 index 00000000..0c92d85b Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf2.co new file mode 100755 index 00000000..f1e09e8d Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf2.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3.co new file mode 100755 index 00000000..b1316463 Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3.co differ diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2.co 
new file mode 100755
index 00000000..c414dfd0
Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2.co differ
diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2.co
new file mode 100755
index 00000000..d9c43b1f
Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2.co differ
diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x64.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x64.co
new file mode 100755
index 00000000..d210b3d6
Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_32x64.co differ
diff --git a/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_48x128.co b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_48x128.co
new file mode 100755
index 00000000..2e682597
Binary files /dev/null and b/hsa/fmoe_2stages/fmoe_stage1_bf16_pertokenFp8_g1u1_48x128.co differ
diff --git a/hsa/fmoe_2stages/tune.py b/hsa/fmoe_2stages/tune.py
new file mode 100644
index 00000000..ff0adaaf
--- /dev/null
+++ b/hsa/fmoe_2stages/tune.py
@@ -0,0 +1,337 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+import torch
+import torch.nn.functional as F
+import aiter
+import pandas as pd
+import argparse
+from aiter.fused_moe_bf16_asm import (
+    fused_topk,
+    moe_sorting_ck,
+)
+from aiter.ops.shuffle import shuffle_weight
+from aiter.utility.mp_tuner import mp_tuner
+from aiter.test_common import checkAllclose
+from aiter import QuantType
+
+torch.set_default_device("cuda")
+
+
+def asm_stage1(
+    a1_qt,
+    w1_qt,
+    w2_qt,
+    sorted_ids,
+    sorted_weights,
+    sorted_expert_ids,
+    num_valid_ids,
+    out,
+    kernelName,
+    blockM,
+    a1_scale,
+    w1_scale,
+):
+    aiter.moe_stage1_fp8_g1u1(
+        a1_qt,
+        w1_qt,
+        w2_qt,
+        sorted_ids,
+        sorted_weights,
+        sorted_expert_ids,
+        num_valid_ids,
+        out,
+        kernelName,
+        blockM,
+        a1_scale,
+        w1_scale,
+    )
+    return out
+
+
+def ck_stage1(
+    input,
+    w1,
+    w2,
+    sorted_ids,
+    sorted_expert_ids,
+    num_valid_ids,
+    out,
+    topk,
+    w1_scale,
+    a1_scale,
+    blockM,
+    token,
+    dtype,
+):
+    tmp = torch.empty(
+        (token, topk, w1.shape[1]),
+        dtype=dtype,
+        device=input.device,
+    )
+    aiter.ck_moe_stage1(
+        input,
+        w1,
+        w2,
+        sorted_ids,
+        sorted_expert_ids,
+        num_valid_ids,
+        tmp,
+        topk,
+        w1_scale,
+        a1_scale,
+        blockM,
+    )
+    aiter.silu_and_mul(out, tmp)
+    return out
+
+
+def torch_moe_stage1(
+    hidden_states,
+    w1,  # E, inter_dim*2, model_dim
+    w2,  # E, model_dim, inter_dim
+    topk_weight,
+    topk_ids,
+    dtype=torch.float16,
+    # following for quant
+    fc1_scale=None,  # [expert, inter_dim, 1]
+    w1_scale=None,  # [1]
+    a1_scale=None,  # [expert]
+    block_size=32,
+):
+    ctype = torch.float  # compute type
+    hidden_states = hidden_states.to(ctype)
+    w1 = w1.to(ctype)
+
+    B, D = hidden_states.shape
+    topk = topk_weight.shape[1]
+    N = w1.shape[1]
+    num_experts, model_dim, inter_dim = w2.shape
+
+    max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk
+
+    # goes to quant D_w8a8/w8a8
+    if fc1_scale is not None:
+        w1 = (w1.view(-1, D) * fc1_scale.view(-1, 1)).view(num_experts, -1, D)
+    if a1_scale is not None and w1_scale is not None:
+        hidden_states = hidden_states * a1_scale
+        w1 = w1 * w1_scale.view(num_experts, -1, 1)
+
+    hidden_states = hidden_states.view(B, -1, D).repeat(1, topk, 1)
+
+    out = torch.zeros(
+        (B, topk, N),
+        dtype=ctype,
+        device=hidden_states.device,
+    )
+    for E_id in range(w1.shape[0]):
+        mask = topk_ids == E_id
+        if mask.sum():
+            sub_tokens = hidden_states[mask]
+            act_input = sub_tokens @ (w1[E_id].transpose(0, 1))
+            out[mask] = act_input
+
+    return out.to(dtype)
+
+
+def go(
+    untunedf,
+    tunedf,
+):
+    blockMs = [16, 32, 48, 64, 128, 160]
+    asm_kernels = {
+        16: [
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_16x64",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_16x64_pf2",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_16x512_pf2",
+        ],
+        32: [
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x64",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf2",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_2tg_pf3",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x128_3tg_pf2",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_32x512_pf2",
+        ],
+        48: ["fmoe_stage1_bf16_pertokenFp8_g1u1_48x128"],
+        128: [
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_128x128",
+            "fmoe_stage1_bf16_pertokenFp8_g1u1_128x128_pf2",
+        ],
+        160: ["fmoe_stage1_bf16_pertokenFp8_g1u1_160x128_pf2"],
+    }
+    args = [
+        "token",
+        "model_dim",
+        "inter_dim",
+        "expert",
+        "topk",
+        "dtype",
+        "q_dtype",
+        "q_type",
+        "use_g1u1",
+    ]
+    print(untunedf[args])
+    profiles = []
+    bests = []
+    for line in untunedf[args].values:
+        token, model_dim, inter_dim, expert, topk, dtype, q_dtype, q_type, use_g1u1 = (
+            line
+        )
+        dtype = eval(dtype)
+        q_dtype = eval(q_dtype)
+        q_type = eval(q_type)
+        torch_quant = aiter.get_torch_quant(q_type)
+        input = torch.randn((token, model_dim), dtype=dtype)
+        if use_g1u1:
+            w1 = torch.randn((expert, inter_dim * 2, model_dim), dtype=dtype) / 10
+        else:
+            w1 = torch.randn((expert, inter_dim, model_dim), dtype=dtype)
+        w2 = torch.randn((expert, model_dim, inter_dim), dtype=dtype)
+
+        score = torch.randn((token, expert), dtype=dtype)
+        topk_weights, topk_ids = fused_topk(input, score, topk, True)
+        w1_qt, w1_scale = torch_quant(w1, quant_dtype=q_dtype)
+        w2_qt, w2_scale = torch_quant(w2, quant_dtype=q_dtype)
+        w1_qt = w1_qt.view(w1.shape)
+        w2_qt = w2_qt.view(w2.shape)
+        a1_qt, a1_scale = torch_quant(input, quant_dtype=q_dtype)
+
+        out1_ref = torch_moe_stage1(
+            a1_qt,
+            w1_qt,
+            w2_qt,
+            topk_weights,
+            topk_ids,
+            dtype=dtype,
+            fc1_scale=None,
+            w1_scale=w1_scale,
+            a1_scale=a1_scale,
+        )
+        gate, up = out1_ref.split([inter_dim, inter_dim], dim=-1)
+        ref = F.silu(gate) * up
+
+        tasks = []
+        tasks_ck = []
+        for blockM in blockMs:
+            sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
+                moe_sorting_ck(topk_ids, topk_weights, expert, model_dim, dtype, blockM)
+            )
+            out = torch.empty(
+                (token, topk, inter_dim),
+                dtype=dtype,
+            )
+            for el in asm_kernels.get(blockM, []):
+                tasks.append(
+                    (
+                        el,  # tag
+                        asm_stage1,  # func
+                        (
+                            a1_qt,
+                            shuffle_weight(w1_qt, (16, 16)),
+                            shuffle_weight(w2_qt, (16, 16)),
+                            sorted_ids,
+                            sorted_weights,
+                            sorted_expert_ids,
+                            num_valid_ids,
+                            out,
+                            el,
+                            blockM,
+                            a1_scale,
+                            w1_scale,
+                        ),
+                    )
+                )
+
+            if blockM in [32, 64, 128]:
+                tasks_ck.append(
+                    (
+                        f"ck_{blockM}",  # tag
+                        ck_stage1,  # func
+                        (
+                            a1_qt,
+                            shuffle_weight(w1_qt, layout=(32, 32)),
+                            w2_qt,
+                            sorted_ids,
+                            sorted_expert_ids,
+                            num_valid_ids,
+                            out,
+                            topk,
+                            w1_scale,
+                            a1_scale,
+                            blockM,
+                            token,
+                            dtype,
+                        ),
+                    )
+                )
+        rets = mp_tuner(tasks + tasks_ck)
+
+        profileDF = []
+        for tag, us, res in rets:
+            err = checkAllclose(
+                ref.to("cpu"), res, msg=f"[{tag:<50}]: {us:.2f}us ...... "
" + ) + profileDF.append( + [ + token, + model_dim, + inter_dim, + expert, + topk, + dtype, + q_dtype, + q_type, + use_g1u1, + us, + tag, + f"{err:.1%}", + ] + ) + profileDF = pd.DataFrame(profileDF, columns=args + ["us", "tag", "err"]) + best_one = profileDF.loc[profileDF["us"].idxmin()] + prorfiles.append(profileDF) + bests.append(best_one) + return pd.concat(prorfiles), pd.concat(bests, axis=1).T + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-i", + "--untune_file", + default="aiter/configs/untuned_fmoe.csv", + required=False, + help="input", + ) + + parser.add_argument( + "-o", + "--tune_file", + default="aiter/configs/tuned_fmoe.csv", + required=False, + help="output: tuning result store this file", + ) + parser.add_argument( + "-o2", + "--profile_file", + default="aiter/configs/profile_fmoe.csv", + required=False, + help="output: tuning result store this file", + ) + + parser.add_argument( + "--sort", + action="store_true", + required=False, + help="Arranged according to the B M N K size", + ) + + args = parser.parse_args() + untunedf = pd.read_csv(args.untune_file) + tunedf = None + # tunedf = pd.read_csv(args.tune_file) + profiles, tunedf = go(untunedf, tunedf) + tunedf.to_csv(args.tune_file, index=False) + profiles.to_csv(args.profile_file, index=False) diff --git a/op_tests/test_kvcache.py b/op_tests/test_kvcache.py index c8cd5bb5..68f7aa82 100644 --- a/op_tests/test_kvcache.py +++ b/op_tests/test_kvcache.py @@ -26,7 +26,6 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay k_scale = quantCfg['k_scale'] v_scale = quantCfg['v_scale'] key, k_scale_ = aiter.pertoken_quant(key, - y_scale_dtype=quantCfg['y_scale_dtype'], quant_dtype=quantCfg['quant_dtype']) k_scale_ = k_scale_.permute(0, 1, 3, 2).view( num_batch*num_tokens, num_heads).contiguous() @@ -49,7 +48,6 @@ def run_torch(key, value, k_cache, v_cache, slot_mapping, block_size, x, asm_lay if quantCfg: value, v_scale_ = aiter.pertoken_quant(value, - y_scale_dtype=quantCfg['y_scale_dtype'], quant_dtype=quantCfg['quant_dtype']) v_scale_ = v_scale_.permute(0, 1, 3, 2).view( num_batch*num_tokens, num_heads).contiguous() @@ -131,8 +129,7 @@ def test_reshape_and_cache(ctx_lens: int, k_cache = torch.empty(k_cache_shape, dtype=DTyoe_KVCache, device=device) v_cache = torch.empty(v_cache_shape, dtype=DTyoe_KVCache, device=device) if quantCfg: - k_scale = torch.empty(kv_scale_shape, - dtype=quantCfg['y_scale_dtype'], device=key.device) + k_scale = torch.empty(kv_scale_shape, device=key.device) v_scale = torch.empty_like(k_scale) quantCfg['k_scale'] = k_scale.clone() quantCfg['v_scale'] = v_scale.clone() @@ -203,13 +200,10 @@ def test_reshape_and_cache(ctx_lens: int, torch.bfloat16, torch.bfloat16) print('\nstart quant fp16->fp8') test_reshape_and_cache(4097, 128, (8, 1), 128, 16, - torch.float16, torch.float8_e4m3fnuz, quantCfg={'y_scale_dtype': torch.float, - 'quant_dtype': torch.float8_e4m3fnuz}) + torch.float16, torch.float8_e4m3fnuz, quantCfg={'quant_dtype': torch.float8_e4m3fnuz}) print('\nstart quant fp16->i8') test_reshape_and_cache(4097, 128, (8, 1), 128, 16, - torch.float16, torch.int8, quantCfg={'y_scale_dtype': torch.float, - 'quant_dtype': torch.int8}) + torch.float16, torch.int8, quantCfg={'quant_dtype': torch.int8}) print('\nstart quant bf16->i8') test_reshape_and_cache(4097, 128, (8, 1), 128, 16, - torch.bfloat16, torch.int8, quantCfg={'y_scale_dtype': torch.float, - 'quant_dtype': torch.int8}) + torch.bfloat16, torch.int8, 
diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py
index 50f6ce55..77c5f15c 100644
--- a/op_tests/test_moe_2stage.py
+++ b/op_tests/test_moe_2stage.py
@@ -9,25 +9,36 @@ import os
 from typing import Any, Callable, Dict, Optional, Tuple
 
 import aiter
-from aiter.test_common import checkAllclose, perftest
-from aiter import pertoken_quant
-from aiter.fused_moe_gelu import fused_topk
-from aiter.fused_moe_bf16_asm import asm_moe, torch_moe, moe_sorting_ck, ck_moe_2stages
+from aiter.test_common import checkAllclose, perftest, benchmark
+from op_tests.int4_utils import *
+
+from aiter.fused_moe_bf16_asm import (
+    fused_topk,
+    asm_moe,
+    torch_moe,
+    moe_sorting_ck,
+    ck_moe_2stages,
+)
 from aiter.ops.shuffle import shuffle_weight
+from aiter import ActivationType
+
+# torch.int4 = torch.uint32
 
 
 @perftest(num_iters=3)
-def torch_moe_stage1(hidden_states,
-                     w1,  # E, inter_dim*2, model_dim
-                     w2,  # E, model_dim, inter_dim
-                     topk_weight, topk_ids,
-                     dtype=torch.float16,
-                     # following for quant
-                     fc1_scale=None,  # [expert, inter_dim, 1]
-                     w1_scale=None,  # [1]
-                     a1_scale=None,  # [expert]]
-                     block_size=32
-                     ):
+def torch_moe_stage1(
+    hidden_states,
+    w1,  # E, inter_dim*2, model_dim
+    w2,  # E, model_dim, inter_dim
+    topk_weight,
+    topk_ids,
+    dtype=torch.float16,
+    # following for quant
+    fc1_scale=None,  # [expert, inter_dim, 1]
+    w1_scale=None,  # [1]
+    a1_scale=None,  # [expert]
+    block_size=32,
+):
     ctype = torch.float  # compute type
     hidden_states = hidden_states.to(ctype)
     w1 = w1.to(ctype)
@@ -36,8 +47,6 @@ def torch_moe_stage1(hidden_states,
     topk = topk_weight.shape[1]
     N = w1.shape[1]
     num_experts, model_dim, inter_dim = w2.shape
-    hidden_states = hidden_states.view(
-        B, -1, D).repeat(1, topk, 1)
 
     max_num_tokens_padded = topk_ids.numel() + num_experts * block_size - topk
@@ -46,7 +55,9 @@
         w1 = (w1.view(-1, D) * fc1_scale.view(-1, 1)).view(num_experts, -1, D)
     if a1_scale is not None and w1_scale is not None:
         hidden_states = hidden_states * a1_scale
-        w1 = w1 * w1_scale.view(-1, 1, 1)
+        w1 = w1 * w1_scale.view(w1_scale.shape[0], -1, 1)
+
+    hidden_states = hidden_states.view(B, -1, D).repeat(1, topk, 1)
 
     out = torch.zeros(
         (B, topk, N),
@@ -64,17 +75,21 @@
 
 
 @perftest(num_iters=3)
-def torch_moe_stage2(hidden_states,
-                     w1,  # E, inter_dim*2, model_dim
-                     w2,  # E, model_dim, inter_dim
-                     topk_weights, topk_ids,
-                     sorted_weights, sorted_ids,
-                     sorted_expert_ids, num_valid_ids,
-                     dtype=torch.float16,
-                     w2_scale=None,  # [1]
-                     a2_scale=None,  # [expert]]
-                     block_size=32
-                     ):
+def torch_moe_stage2(
+    hidden_states,
+    w1,  # E, inter_dim*2, model_dim
+    w2,  # E, model_dim, inter_dim
+    topk_weights,
+    topk_ids,
+    sorted_weights,
+    sorted_ids,
+    sorted_expert_ids,
+    num_valid_ids,
+    dtype=torch.float16,
+    w2_scale=None,  # [1]
+    a2_scale=None,  # [expert]
+    block_size=32,
+):
     ctype = torch.float  # compute type
     hidden_states = hidden_states.to(ctype)
     w2 = w2.to(ctype)
@@ -83,11 +98,12 @@ def torch_moe_stage2(hidden_states,
     # M, _ = hidden_states.shape
     num_experts, model_dim, inter_dim = w2.shape
     max_num_m_blocks = sorted_expert_ids.shape[0]
+    hidden_states = hidden_states.view(token_num, topk, inter_dim)
 
     # gose to quant D_w8a8/w8a8
     if a2_scale is not None and w2_scale is not None:
-        hidden_states = hidden_states * a2_scale
-        w2 = w2 * w2_scale.view(-1, 1, 1)
+        hidden_states = hidden_states * a2_scale.view(a2_scale.shape[0], -1, 1)
+        w2 = w2 * w2_scale.view(num_experts, -1, 1)
 
     out = torch.zeros(
         (token_num, topk, model_dim),
@@ -103,106 +119,67 @@
     return (out * topk_weights.view(token_num, -1, 1)).sum(1).to(dtype)
 
 
-def torch_moe(hidden_states, w1, w2, topk_weight, topk_ids,
-              # following for quant
-              fc1_scale=None,  # [expert, inter_dim, 1]
-              fc2_scale=None,  # [expert, model_dim, 1]
-              fc1_smooth_scale=None,  # [expert, 1, model_dim]
-              fc2_smooth_scale=None,  # [expert, 1, inter_dim]
-              ):
-    B, D = hidden_states.shape
-    topk = topk_weight.shape[1]
-    dtype = hidden_states.dtype
-    hidden_states = hidden_states.view(
-        B, -1, D).repeat(1, topk, 1)
-    out = torch.zeros(
-        (B, topk, D),
-        dtype=dtype,
-        device=hidden_states.device,
-    )
-    # g1u1(w1 include gate and up)
-    if w2.shape[2]*2 == w1.shape[1]:
-        moeType = "g1u1"
-        inter_dim = w2.shape[2]
-    # g1u0(w1 only include gate)
-    else:
-        moeType = "g1u0"
-        inter_dim = w1.shape[1]
-    # gose to quant D_w8a8/w8a8
-    if fc1_scale is not None:
-        expert = w1.shape[0]
-        w2D = w2.shape[-1]
-        w1 = (w1.view(-1, D).to(fc1_scale) *
-              fc1_scale.view(-1, 1)).to(dtype).view(expert, -1, D)
-        w2 = (w2.view(-1, w2D).to(fc2_scale) *
-              fc2_scale.view(-1, 1)).to(dtype).view(expert, -1, w2D)
-    if fc1_smooth_scale is not None:
-        expert = fc1_smooth_scale.shape[0]
-        fc1_smooth_scale = fc1_smooth_scale.view(expert, -1).to(dtype)
-        fc2_smooth_scale = fc2_smooth_scale.view(expert, -1).to(dtype)
-
-    for E_id in range(w1.shape[0]):
-        mask = topk_ids == E_id
-        if mask.sum():
-            sub_tokens = hidden_states[mask]
-            if fc1_smooth_scale is not None:
-                sub_tokens = sub_tokens * (
-                    fc1_smooth_scale[E_id])
-            act_input = sub_tokens @ (w1[E_id].transpose(0, 1))
-            if moeType == "g1u1":
-                gate, up = act_input.split([inter_dim, inter_dim], dim=-1)
-                act_out = F.silu(gate) * up
-            else:
-                act_out = F.gelu(act_input)
-            if fc2_smooth_scale is not None:
-                act_out = act_out * (
-                    fc2_smooth_scale[E_id])
-            out[mask] = act_out @ (w2[E_id].transpose(0, 1))
-
-    return (
-        out * topk_weight.view(B, -1, 1).to(out.dtype)
-    ).sum(dim=1)
-
-
 @perftest()
-def ck_moe_stage1(hidden_states,
-                  w1,  # [E, inter_dim*2, model_dim]
-                  w2,  # [E, model_dim, inter_dim]
-                  sorted_token_ids,  # [max_num_tokens_padded]
-                  sorted_expert_ids,  # [max_num_m_blocks]
-                  num_valid_ids,  # [1]
-                  w1_scale, a1_scale, dtype,
-                  topk,
-                  block_size=32
-                  ):
+def ck_moe_stage1(
+    hidden_states,
+    w1,  # [E, inter_dim*2, model_dim]
+    w2,  # [E, model_dim, inter_dim]
+    sorted_token_ids,  # [max_num_tokens_padded]
+    sorted_expert_ids,  # [max_num_m_blocks]
+    num_valid_ids,  # [1]
+    w1_scale,
+    a1_scale,
+    dtype,
+    topk,
+    block_size=32,
+):
     token_num = hidden_states.shape[0]
     D = w1.shape[1]
-    num_experts, model_dim, inter_dim = w2.shape
+    # num_experts, model_dim, inter_dim = w2.shape
     max_num_tokens_padded = sorted_token_ids.shape[0]
     # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size
-    out = torch.zeros(
+    out = torch.empty(
         (token_num, topk, D),
         dtype=dtype,
         device=hidden_states.device,
     )
-    aiter.ck_moe_stage1(hidden_states, w1, w2, sorted_token_ids,
-                        sorted_expert_ids, num_valid_ids, out, topk, w1_scale, a1_scale, block_size)
+    aiter.ck_moe_stage1(
+        hidden_states,
+        w1,
+        w2,
+        sorted_token_ids,
+        sorted_expert_ids,
+        num_valid_ids,
+        out,
+        topk,
+        w1_scale,
+        a1_scale,
+        block_size,
+    )
+    tmp = torch.empty(
+        (token_num, topk, int(D / 2)), dtype=dtype, device=hidden_states.device
+    )
+    aiter.silu_and_mul(tmp, out)
+    out = tmp
     return out
 
 
 @perftest()
-def ck_moe_stage2(hidden_states,
-                  w1,  # [E, inter_dim*2, model_dim]
-                  w2,  # [E, model_dim, inter_dim]
-                  sorted_token_ids,  # [max_num_tokens_padded]
-                  sorted_expert_ids,  # [max_num_m_blocks]
-                  sorted_weights,  # [max_num_tokens_padded]
-                  num_valid_ids,  # [1]
-                  w2_scale, a2_scale, dtype,
-                  topk,
-                  block_size=32
-                  ):
+def ck_moe_stage2(
+    hidden_states,
+    w1,  # [E, inter_dim*2, model_dim]
+    w2,  # [E, model_dim, inter_dim]
+    sorted_token_ids,  # [max_num_tokens_padded]
+    sorted_expert_ids,  # [max_num_m_blocks]
+    sorted_weights,  # [max_num_tokens_padded]
+    num_valid_ids,  # [1]
+    w2_scale,
+    a2_scale,
+    dtype,
+    topk,
+    block_size=32,
+):
     token_num = hidden_states.shape[0]
     D = w2.shape[1]
     num_experts, model_dim, inter_dim = w2.shape
@@ -214,152 +191,310 @@ def ck_moe_stage2(hidden_states,
         dtype=dtype,
         device=hidden_states.device,
     )
-    aiter.ck_moe_stage2(hidden_states, w1, w2, sorted_token_ids,
-                        sorted_expert_ids, sorted_weights,
-                        num_valid_ids, out, topk, w2_scale, a2_scale, block_size)
+    aiter.ck_moe_stage2(
+        hidden_states,
+        w1,
+        w2,
+        sorted_token_ids,
+        sorted_expert_ids,
+        sorted_weights,
+        num_valid_ids,
+        out,
+        topk,
+        w2_scale,
+        a2_scale,
+        block_size,
+    )
     return out
 
 
 @perftest()
-def ck_moe_fused_2stages(hidden_states,
-                         # [expert(local_expert:EP), inter_dim(*2), dim] N,K
-                         w1,
-                         w2,  # [expert(local_expert:EP), dim, inter_dim]
-                         topk_weight, topk_ids,
-                         # following for int8 quant
-                         # [expert(local_expert:EP), inter_dim, 1]
-                         fc1_scale=None,
-                         # [expert(local_expert:EP), model_dim, 1]
-                         fc2_scale=None,
-                         block_size=32,
-                         a1_scale=None
-                         ):
-    return ck_moe_2stages(hidden_states, w1, w2, topk_weight, topk_ids,
-                          fc1_scale, fc2_scale, block_size=block_size, a1_scale=a1_scale)
-
-
-def test_fmoe(dtype, token, model_dim, inter_dim, E, topk, quant='No', use_g1u1=False, shared_E=0):
+def ck_moe_fused_2stages(
+    hidden_states,
+    # [expert(local_expert:EP), inter_dim(*2), dim] N,K
+    w1,
+    w2,  # [expert(local_expert:EP), dim, inter_dim]
+    topk_weight,
+    topk_ids,
+    # following for int8 quant
+    # [expert(local_expert:EP), inter_dim, 1]
+    fc1_scale=None,
+    # [expert(local_expert:EP), model_dim, 1]
+    fc2_scale=None,
+    block_size=32,
+    a1_scale=None,
+):
+    return ck_moe_2stages(
+        hidden_states,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        fc1_scale,
+        fc2_scale,
+        block_size=block_size,
+        a1_scale=a1_scale,
+    )
+
+
+@perftest()
+def asm_moe_stage1(
+    hidden_states,
+    w1,  # [E, inter_dim*2, model_dim]
+    w2,  # [E, model_dim, inter_dim]
+    sorted_token_ids,  # [max_num_tokens_padded]
+    sorted_weights,
+    sorted_expert_ids,  # [max_num_m_blocks]
+    num_valid_ids,  # [1]
+    w1_scale,
+    a1_scale,
+    dtype,
+    topk,
+    block_size=128,
+):
+
+    token_num = hidden_states.shape[0]
+    D = w1.shape[1]
+    D = int(D / 2)
+
+    out = torch.empty(
+        (token_num, topk, D),
+        dtype=dtype,
+        device=hidden_states.device,
+    )
+
+    aiter.moe_stage1_fp8_g1u1(
+        hidden_states,
+        w1,
+        w2,
+        sorted_token_ids,
+        sorted_weights,
+        sorted_expert_ids,
+        num_valid_ids,
+        out,
+        "",
+        block_size,
+        a1_scale,
+        w1_scale,
+    )
+    return out
+
+
+@benchmark()
+def test_fmoe(
+    dtype,
+    token,
+    model_dim,
+    inter_dim,
+    E,
+    topk,
+    qType,
+    AQDType,
+    WQDType,
+    BLOCK_SIZE_M,
+    use_g1u1=False,
+):
+    torch_quant = aiter.get_torch_quant(qType)
     input = torch.randn((token, model_dim), dtype=dtype, device="cuda")
     if use_g1u1:
-        w1 = torch.randn((E+shared_E, inter_dim*2, model_dim),
-                         dtype=dtype, device="cuda") / 10
+        w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype, device="cuda")
     else:
-        w1 = torch.randn((E+shared_E, inter_dim, model_dim),
-                         dtype=dtype, device="cuda")
-    w2 = torch.randn((E+shared_E, model_dim, inter_dim),
-                     dtype=dtype, device="cuda")
+        w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype, device="cuda")
+    w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype, device="cuda")
+
+    score = torch.randn((token, E), device="cuda", dtype=dtype)
     topk_weights, topk_ids = fused_topk(input, score, topk, True)
-    E, model_dim, inter_dim = w2.shape
-    M, topk = topk_ids.shape
-    BLOCK_SIZE_M = 128
-    sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting_ck(topk_ids, topk_weights, E,
-                                                                                           model_dim, dtype, BLOCK_SIZE_M)
-
-    quant_dtype = torch.float8_e4m3fnuz
-    w1_qt, w1_scale = aiter.pertoken_quant(w1.view(E, -1),
-                                           quant_dtype=quant_dtype)
-    w2_qt, w2_scale = aiter.pertoken_quant(w2.view(E, -1),
-                                           quant_dtype=quant_dtype)
-    w1_qt = w1_qt.view(w1.shape)
-    w2_qt = w2_qt.view(w2.shape)
-
-    a1_qt, a1_scale = aiter.per_tensor_quant(input, quant_dtype=quant_dtype)
-    # a1_qt, a1_scale = aiter.per_tensor_quant_fp8_hip(input)
-
-    out1_ref, us_ref = torch_moe_stage1(a1_qt, w1_qt,
-                                        w2_qt,
-                                        topk_weights, topk_ids,
-                                        dtype=dtype,
-                                        fc1_scale=None,
-                                        w1_scale=w1_scale,
-                                        a1_scale=a1_scale,
-                                        block_size=BLOCK_SIZE_M)
+    sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
+        moe_sorting_ck(topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M)
+    )
+    if qType == aiter.QuantType.per_Tensor:
+        w1_qt, w1_scale = aiter.pertoken_quant(w1.view(E, -1), quant_dtype=WQDType)
+        w2_qt, w2_scale = aiter.pertoken_quant(w2.view(E, -1), quant_dtype=WQDType)
+    elif qType == aiter.QuantType.per_Token and WQDType == torch.int4:  # int4 w quant
+        w1_qt, w1_scale = aiter.pertoken_quant(w1, quant_dtype=torch.int8, dtypeMax=7)
+        w2_qt, w2_scale = aiter.pertoken_quant(w2, quant_dtype=torch.int8, dtypeMax=7)
+    else:
+        w1_qt, w1_scale = torch_quant(w1, quant_dtype=WQDType)
+        w2_qt, w2_scale = torch_quant(w2, quant_dtype=WQDType)
+    w1_qt = w1_qt_aiter = w1_qt.view(w1.shape)
+    w2_qt = w2_qt_aiter = w2_qt.view(w2.shape)
+
+    a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType)
+
+    out1_ref, us_ref = torch_moe_stage1(
+        a1_qt,
+        w1_qt,
+        w2_qt,
+        topk_weights,
+        topk_ids,
+        dtype=dtype,
+        fc1_scale=None,
+        w1_scale=w1_scale,
+        a1_scale=a1_scale,
+        block_size=BLOCK_SIZE_M,
+    )
     if use_g1u1:
         gate, up = out1_ref.split([inter_dim, inter_dim], dim=-1)
-        input2 = F.silu(gate) * up
+        out1_ref = F.silu(gate) * up
     else:
-        input2 = F.gelu(out1_ref)
-    a2_qt, a2_scale = aiter.per_tensor_quant(input2, quant_dtype=quant_dtype)
-    # a2_qt, a2_scale = aiter.per_tensor_quant_fp8_hip(input2)
-    out2_ref, us_ref = torch_moe_stage2(a2_qt,
-                                        w1_qt,  # E, inter_dim*2, model_dim
-                                        w2_qt,  # E, model_dim, inter_dim
-                                        topk_weights, topk_ids,
-                                        sorted_weights, sorted_ids,
-                                        sorted_expert_ids, num_valid_ids,
-                                        dtype=dtype,
-                                        # [expert, inter_dim, 1]
-                                        w2_scale=w2_scale,
-                                        a2_scale=a2_scale,
-                                        block_size=BLOCK_SIZE_M
-                                        )
-
-    out_ref = torch_moe(input, w1, w2, topk_weights, topk_ids)
+        out1_ref = F.gelu(out1_ref)
+    # out_ref = torch_moe(input, w1, w2, topk_weights, topk_ids)
     # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage")
+    if WQDType == torch.int4:  # int4 w quant
+        w1_qt_aiter = rearrange_4bit_elements(convert_int8_to_uint32_int4(w1_qt_aiter))
+        w2_qt_aiter = rearrange_4bit_elements(convert_int8_to_uint32_int4(w2_qt_aiter))
+        w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(32, 32))
+        w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(32, 32))
+
+    out1_ck, us = ck_moe_stage1(
+        a1_qt,
+        w1_qt_aiter,
+        w2_qt_aiter,
+        sorted_ids,
+        sorted_expert_ids,
+        num_valid_ids,
+        w1_scale,
+        a1_scale,
+        dtype,
+        topk,
+        BLOCK_SIZE_M,
+    )
 
-    out1_qt, us = ck_moe_stage1(a1_qt,
-                                shuffle_weight(w1_qt, layout=(32, 32)),
-                                w2,
-                                sorted_ids,
-                                sorted_expert_ids,
-                                num_valid_ids,
-                                w1_scale, a1_scale,
-                                dtype, topk, BLOCK_SIZE_M)
-    checkAllclose(out1_ref, out1_qt,
-                  msg=f'ck_moe_stage1:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{quant_dtype})')
+    checkAllclose(
+        out1_ref,
+        out1_ck,
+        msg=f"[perf] ck_moe_stage1:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{AQDType})",
+    )
 
-    if use_g1u1:
-        gate, up = out1_qt.split([inter_dim, inter_dim], dim=-1)
-        input2 = F.silu(gate) * up
-    else:
-        input2 = F.gelu(out1_qt)
-    # a2_qt, a2_scale = aiter.per_tensor_quant_fp8_hip(input2)
-    a2_qt, a2_scale = aiter.per_tensor_quant(input2, quant_dtype=quant_dtype)
-
-    out2_qt, us = ck_moe_stage2(a2_qt,
-                                w1_qt,
-                                shuffle_weight(w2_qt, layout=(32, 32)),
-                                sorted_ids,
-                                sorted_expert_ids,
-                                sorted_weights,
-                                num_valid_ids,
-                                w2_scale, a2_scale,
-                                dtype, topk, BLOCK_SIZE_M)
-    checkAllclose(out2_ref, out2_qt,
-                  msg=f'ck_moe_stage2:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{quant_dtype})')
-
-    out_ck_qt, us = ck_moe_fused_2stages(input,
-                                         shuffle_weight(
-                                             w1_qt, layout=(32, 32)),
-                                         shuffle_weight(
-                                             w2_qt, layout=(32, 32)),
-                                         topk_weights, topk_ids,
-                                         w1_scale, w2_scale,
-                                         # block_size=BLOCK_SIZE_M
-                                         )
-
-    checkAllclose(out2_ref, out_ck_qt,
-                  msg=f'ck_moe_fused_2stages:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{quant_dtype})')
-
-    out_ck_nqt, us = ck_moe_fused_2stages(input,
-                                          shuffle_weight(w1, layout=(32, 32)),
-                                          shuffle_weight(w2, layout=(32, 32)),
-                                          topk_weights, topk_ids,
-                                          None, None,
-                                          # block_size=BLOCK_SIZE_M
-                                          )
-
-    checkAllclose(out_ref, out_ck_nqt,
-                  msg=f'ck_moe_fused_2stages:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(No quant)')
-
-
-for dtype in [torch.float16]:
-    for m in [32, 128]:
-        for dim in [8192]:
-            for inter_dim in [6144, 16384]:
+    if WQDType != torch.int4:
+        # asm int4 2-stage is not supported yet
+        if qType == aiter.QuantType.per_Tensor:
+            a1_scale = a1_scale.view(1).repeat(token)
+            w1_scale = w1_scale.view(E, 1).repeat(1, w1.shape[-2])
+        out1_asm, us = asm_moe_stage1(
+            a1_qt,
+            shuffle_weight(w1_qt, (16, 16)),
+            shuffle_weight(w2_qt, (16, 16)),
+            sorted_ids,
+            sorted_weights,
+            sorted_expert_ids,
+            num_valid_ids,
+            w1_scale,
+            a1_scale,
+            dtype,
+            topk,
+            BLOCK_SIZE_M,
+        )
+        checkAllclose(
+            out1_ref,
+            out1_asm,
+            msg=f"[perf] asm_moe_stage1:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{AQDType})",
+        )
+
+    # ######################## stage 2 start ###########
+    if qType == aiter.QuantType.per_Token:
+        out1_ref = out1_ref.view(token, -1)
+    a2_qt, a2_scale = torch_quant(out1_ref, quant_dtype=AQDType)
+    out2_ref, us_ref = torch_moe_stage2(
+        a2_qt,
+        w1_qt,  # E, inter_dim*2, model_dim
+        w2_qt,  # E, model_dim, inter_dim
+        topk_weights,
+        topk_ids,
+        sorted_weights,
+        sorted_ids,
+        sorted_expert_ids,
+        num_valid_ids,
+        dtype=dtype,
+        # [expert, inter_dim, 1]
+        w2_scale=w2_scale,
+        a2_scale=a2_scale,
+        block_size=BLOCK_SIZE_M,
+    )
+
+    if qType == aiter.QuantType.per_Token:
+        out1_ck = out1_ck.view(token, -1)
+    a2_qt, a2_scale = torch_quant(out1_ck, quant_dtype=AQDType)
+    a2_qt = a2_qt.view(token, topk, -1)
+    out2_ck, us = ck_moe_stage2(
+        a2_qt,
+        w1_qt_aiter,
+        w2_qt_aiter,
+        sorted_ids,
+        sorted_expert_ids,
+        sorted_weights,
+        num_valid_ids,
+        w2_scale,
+        a2_scale,
+        dtype,
+        topk,
+        BLOCK_SIZE_M,
+    )
+    checkAllclose(
+        out2_ref,
+        out2_ck,
+        msg=f"ck_moe_stage2:{us:.2f} us, {token*model_dim*inter_dim*topk*2/us/1000/1000:.2f} tflops......(quant:{AQDType})",
+    )
+    # # ######################## stage 2 end ###########
+
+
+# per Token quant/a8w8
+for dtype in [torch.bfloat16]:
+    for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096][-1:]:
+        for dim in [6144]:
+            for inter_dim in [4096]:
+                expert, topk = 8, 2
+                test_fmoe(
+                    dtype,
+                    m,
+                    dim,
+                    inter_dim,
+                    expert,
+                    topk,
+                    aiter.QuantType.per_Token,
+                    torch.float8_e4m3fnuz,
+                    torch.float8_e4m3fnuz,
+                    BLOCK_SIZE_M=128,
+                    use_g1u1=True,
+                )
+
+# per Tensor quant/a8w8
+for dtype in [torch.bfloat16]:
+    for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096]:
+        for dim in [6144]:
+            for inter_dim in [4096]:
+                expert, topk = 8, 2
+                test_fmoe(
+                    dtype,
+                    m,
+                    dim,
+                    inter_dim,
+                    expert,
+                    topk,
+                    aiter.QuantType.per_Tensor,
+                    torch.float8_e4m3fnuz,
+                    torch.float8_e4m3fnuz,
+                    BLOCK_SIZE_M=32,
+                    use_g1u1=True,
+                )
+# per Token quant/a8w4
+for dtype in [torch.bfloat16]:
+    for m in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096]:
+        for dim in [6144]:
+            for inter_dim in [4096]:
                 expert, topk = 8, 2
-                test_fmoe(dtype, m, dim, inter_dim, expert, topk,
-                          quant='fp8quant', use_g1u1=True)
+                test_fmoe(
+                    dtype,
+                    m,
+                    dim,
+                    inter_dim,
+                    expert,
+                    topk,
+                    aiter.QuantType.per_Token,
+                    torch.float8_e4m3fnuz,
+                    torch.int4,
+                    BLOCK_SIZE_M=32,
+                    use_g1u1=True,
+                )
diff --git a/op_tests/test_smoothquant.py b/op_tests/test_smoothquant.py
index 5d1c59fb..74daef33 100644
--- a/op_tests/test_smoothquant.py
+++ b/op_tests/test_smoothquant.py
@@ -12,7 +12,7 @@
 @perftest()
 def run_torch(input, x_scale, y_scale_dtype=torch.float32):
     output, y_scale = aiter.pertoken_quant(
-        input, x_scale=x_scale, y_scale_dtype=y_scale_dtype)
+        input, x_scale=x_scale, scale_dtype=y_scale_dtype)
     return output, y_scale
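
Note on the new tuning flow: tune.py sweeps every asm kernel listed in asm_kernels (plus the CK fallbacks) over each shape in untuned_fmoe.csv, validates each candidate against the torch reference with checkAllclose, and writes the fastest variant per shape to tuned_fmoe.csv together with its tag and relative error. A minimal, illustrative sketch of how a consumer might look up the tuned kernel for a given shape follows; the helper name is hypothetical and only the CSV columns (token, model_dim, inter_dim, us, tag) and file path come from this diff:

import pandas as pd

def lookup_tuned_kernel(token, model_dim, inter_dim,
                        csv_path="aiter/configs/tuned_fmoe.csv"):
    """Return the fastest tuned kernel tag for a shape, or None if untuned."""
    df = pd.read_csv(csv_path)
    hit = df[(df["token"] == token)
             & (df["model_dim"] == model_dim)
             & (df["inter_dim"] == inter_dim)]
    if hit.empty:
        return None  # caller can fall back to the heuristic in moe_stage1_fp8_g1u1
    # each row stores the winning kernel for that shape; sort by latency to be safe
    return hit.sort_values("us").iloc[0]["tag"]

A tag returned this way (e.g. one of the fmoe_stage1_bf16_pertokenFp8_g1u1_* names above) could then be passed as kernelName to aiter.moe_stage1_fp8_g1u1, which only runs its get_heuristic_kernel selection when kernelName is empty.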