diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index c9bcfd75..353a612b 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit c9bcfd755ed4d2102d76a6f545ac6e9a030d7d8e +Subproject commit 353a612b44a3dac232f5a6b2c4430dab071b3692 diff --git a/README.md b/README.md index adb2893e..5f3ae6e2 100644 --- a/README.md +++ b/README.md @@ -28,3 +28,6 @@ there are number of op test, you can run them like this: `python3 op_tests/test_ |GEMM | D=AxB+C | |FusedMoE | bf16 balabala | |WIP | coming soon... | + +## Ops +1. [INT8/FP8 A8W8 Per-Tensor/Rowwise Scaling GEMM](csrc/ck_gemm_a8w8/README.md) diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv new file mode 100644 index 00000000..031c2055 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,8192,8192 +1,10240,8192 +1,8192,28672 +1,57344,8192 +2,8192,8192 +2,8192,28672 +2,10240,8192 +2,57344,8192 +4,10240,8192 +4,8192,8192 +4,57344,8192 +4,8192,28672 +8,8192,8192 +8,8192,28672 +8,57344,8192 +8,10240,8192 +16,8192,8192 +16,8192,28672 +16,57344,8192 +16,10240,8192 +17,8192,28672 +17,10240,8192 +17,57344,8192 +17,8192,8192 +24,57344,8192 +24,8192,8192 +24,8192,28672 +24,10240,8192 +32,8192,28672 +32,57344,8192 +32,10240,8192 +32,8192,8192 +40,8192,8192 +40,8192,28672 +40,10240,8192 +40,57344,8192 +48,10240,8192 +48,8192,8192 +48,8192,28672 +48,57344,8192 +56,8192,8192 +56,8192,28672 +56,57344,8192 +56,10240,8192 +64,8192,8192 +64,8192,28672 +64,57344,8192 +64,10240,8192 +72,8192,8192 +72,8192,28672 +72,57344,8192 +72,10240,8192 +80,8192,28672 +80,57344,8192 +80,10240,8192 +80,8192,8192 +88,8192,8192 +88,8192,28672 +88,10240,8192 +88,57344,8192 +96,8192,8192 +96,8192,28672 +96,57344,8192 +96,10240,8192 +104,8192,28672 +104,57344,8192 +104,10240,8192 +104,8192,8192 +112,8192,8192 +112,8192,28672 +112,57344,8192 +112,10240,8192 +120,10240,8192 +120,8192,8192 +120,8192,28672 +120,57344,8192 +128,57344,8192 +128,10240,8192 +128,8192,8192 +128,8192,28672 +136,10240,8192 +136,8192,8192 +136,8192,28672 +136,57344,8192 +144,10240,8192 +144,8192,8192 +144,8192,28672 +144,57344,8192 +152,10240,8192 +152,8192,8192 +152,8192,28672 +152,57344,8192 +160,57344,8192 +160,10240,8192 +160,8192,8192 +160,8192,28672 +168,8192,8192 +168,8192,28672 +168,57344,8192 +168,10240,8192 +176,10240,8192 +176,8192,8192 +176,8192,28672 +176,57344,8192 +184,10240,8192 +184,8192,8192 +184,8192,28672 +184,57344,8192 +192,8192,8192 +192,8192,28672 +192,57344,8192 +192,10240,8192 +200,8192,28672 +200,57344,8192 +200,10240,8192 +200,8192,8192 +208,57344,8192 +208,10240,8192 +208,8192,8192 +208,8192,28672 +216,57344,8192 +216,10240,8192 +216,8192,8192 +216,8192,28672 +224,10240,8192 +224,8192,8192 +224,8192,28672 +224,57344,8192 +232,10240,8192 +232,8192,8192 +232,8192,28672 +232,57344,8192 +240,10240,8192 +240,8192,8192 +240,8192,28672 +240,57344,8192 +248,8192,28672 +248,57344,8192 +248,10240,8192 +248,8192,8192 +256,8192,28672 +256,10240,8192 +256,57344,8192 +256,8192,8192 +512,10240,8192 +512,8192,8192 +512,8192,28672 +512,57344,8192 +1024,10240,8192 +1024,8192,8192 +1024,8192,28672 +1024,57344,8192 +1536,8192,8192 +1536,8192,28672 +1536,57344,8192 +1536,10240,8192 +2048,8192,28672 +2048,57344,8192 +2048,8192,8192 +2048,10240,8192 +3072,8192,28672 +3072,10240,8192 +3072,8192,8192 +3072,57344,8192 +4096,10240,8192 +4096,8192,8192 +4096,8192,28672 +4096,57344,8192 
+8192,10240,8192 +8192,8192,8192 +8192,8192,28672 +8192,57344,8192 +16384,8192,28672 +16384,57344,8192 +16384,10240,8192 +16384,8192,8192 +18432,8192,8192 +18432,8192,28672 +18432,57344,8192 +18432,10240,8192 +20480,10240,8192 +20480,8192,8192 +20480,8192,28672 +20480,57344,8192 +32768,57344,8192 +32768,10240,8192 +32768,8192,8192 +32768,8192,28672 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv new file mode 100644 index 00000000..a1abdcda --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,5120,8192 +1,8192,14336 +1,8192,4096 +1,28672,8192 +2,8192,4096 +2,5120,8192 +2,8192,14336 +2,28672,8192 +4,5120,8192 +4,8192,14336 +4,8192,4096 +4,28672,8192 +8,8192,4096 +8,8192,14336 +8,28672,8192 +8,5120,8192 +16,8192,14336 +16,5120,8192 +16,28672,8192 +16,8192,4096 +17,8192,14336 +17,28672,8192 +17,8192,4096 +17,5120,8192 +24,28672,8192 +24,8192,4096 +24,5120,8192 +24,8192,14336 +32,8192,4096 +32,5120,8192 +32,28672,8192 +32,8192,14336 +40,8192,4096 +40,5120,8192 +40,8192,14336 +40,28672,8192 +48,5120,8192 +48,8192,4096 +48,28672,8192 +48,8192,14336 +56,8192,4096 +56,8192,14336 +56,28672,8192 +56,5120,8192 +64,5120,8192 +64,28672,8192 +64,8192,14336 +64,8192,4096 +72,8192,14336 +72,8192,4096 +72,28672,8192 +72,5120,8192 +80,8192,4096 +80,5120,8192 +80,28672,8192 +80,8192,14336 +88,8192,4096 +88,5120,8192 +88,28672,8192 +88,8192,14336 +96,5120,8192 +96,8192,14336 +96,8192,4096 +96,28672,8192 +104,8192,14336 +104,8192,4096 +104,28672,8192 +104,5120,8192 +112,28672,8192 +112,5120,8192 +112,8192,14336 +112,8192,4096 +120,5120,8192 +120,8192,4096 +120,28672,8192 +120,8192,14336 +128,5120,8192 +128,28672,8192 +128,8192,4096 +128,8192,14336 +136,5120,8192 +136,8192,14336 +136,8192,4096 +136,28672,8192 +144,5120,8192 +144,8192,14336 +144,8192,4096 +144,28672,8192 +152,8192,4096 +152,28672,8192 +152,8192,14336 +152,5120,8192 +160,8192,4096 +160,5120,8192 +160,28672,8192 +160,8192,14336 +168,8192,14336 +168,5120,8192 +168,8192,4096 +168,28672,8192 +176,5120,8192 +176,28672,8192 +176,8192,4096 +176,8192,14336 +184,28672,8192 +184,5120,8192 +184,8192,14336 +184,8192,4096 +192,8192,4096 +192,5120,8192 +192,28672,8192 +192,8192,14336 +200,8192,14336 +200,8192,4096 +200,5120,8192 +200,28672,8192 +208,5120,8192 +208,8192,14336 +208,8192,4096 +208,28672,8192 +216,8192,14336 +216,5120,8192 +216,8192,4096 +216,28672,8192 +224,5120,8192 +224,28672,8192 +224,8192,4096 +224,8192,14336 +232,8192,4096 +232,5120,8192 +232,28672,8192 +232,8192,14336 +240,8192,4096 +240,28672,8192 +240,8192,14336 +240,5120,8192 +248,8192,4096 +248,28672,8192 +248,5120,8192 +248,8192,14336 +256,8192,14336 +256,28672,8192 +256,8192,4096 +256,5120,8192 +512,8192,4096 +512,5120,8192 +512,28672,8192 +512,8192,14336 +1024,5120,8192 +1024,8192,4096 +1024,28672,8192 +1024,8192,14336 +1536,8192,4096 +1536,28672,8192 +1536,8192,14336 +1536,5120,8192 +2048,28672,8192 +2048,5120,8192 +2048,8192,14336 +2048,8192,4096 +3072,5120,8192 +3072,8192,14336 +3072,8192,4096 +3072,28672,8192 +4096,5120,8192 +4096,8192,14336 +4096,8192,4096 +4096,28672,8192 +8192,5120,8192 +8192,28672,8192 +8192,8192,4096 +8192,8192,14336 +16384,8192,4096 +16384,5120,8192 +16384,28672,8192 +16384,8192,14336 +18432,8192,4096 +18432,8192,14336 +18432,5120,8192 +18432,28672,8192 +20480,8192,4096 +20480,5120,8192 +20480,28672,8192 +20480,8192,14336 +32768,5120,8192 +32768,28672,8192 +32768,8192,4096 +32768,8192,14336 
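Note: every CSV added under `aiter/configs/a8w8_gemm_model_config/` follows the same three-column `M,N,K` layout; each row is one GEMM problem size for that model at the given tensor-parallel split, and these files are the inputs passed to `csrc/ck_gemm_a8w8/gemm_a8w8_tune.py` via `-i` (see the `csrc/ck_gemm_a8w8/README.md` changes later in this diff). As a minimal illustration (not part of this patch), one of the shape lists can be inspected like this, assuming `pandas` is available:

```python
# Illustrative sketch only -- not part of this patch.
# Load one of the shape-config CSVs added above and list the GEMM
# problem sizes that the a8w8 tuner would sweep for this model/TP split.
import pandas as pd

shapes = pd.read_csv(
    "aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP2.csv"
)
for m, n, k in shapes.itertuples(index=False):
    print(f"GEMM problem size: M={m}, N={n}, K={k}")
```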
diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv new file mode 100644 index 00000000..232b0412 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,2560,8192 +1,14336,8192 +1,8192,2048 +1,8192,7168 +2,2560,8192 +2,8192,7168 +2,8192,2048 +2,14336,8192 +4,14336,8192 +4,8192,2048 +4,8192,7168 +4,2560,8192 +8,8192,7168 +8,2560,8192 +8,8192,2048 +8,14336,8192 +16,2560,8192 +16,8192,2048 +16,14336,8192 +16,8192,7168 +17,2560,8192 +17,14336,8192 +17,8192,7168 +17,8192,2048 +24,14336,8192 +24,2560,8192 +24,8192,7168 +24,8192,2048 +32,8192,7168 +32,2560,8192 +32,14336,8192 +32,8192,2048 +40,8192,7168 +40,8192,2048 +40,2560,8192 +40,14336,8192 +48,14336,8192 +48,8192,7168 +48,8192,2048 +48,2560,8192 +56,8192,7168 +56,2560,8192 +56,8192,2048 +56,14336,8192 +64,8192,7168 +64,8192,2048 +64,14336,8192 +64,2560,8192 +72,8192,2048 +72,8192,7168 +72,2560,8192 +72,14336,8192 +80,8192,7168 +80,2560,8192 +80,14336,8192 +80,8192,2048 +88,8192,7168 +88,8192,2048 +88,14336,8192 +88,2560,8192 +96,8192,2048 +96,8192,7168 +96,2560,8192 +96,14336,8192 +104,8192,7168 +104,2560,8192 +104,14336,8192 +104,8192,2048 +112,8192,2048 +112,2560,8192 +112,14336,8192 +112,8192,7168 +120,14336,8192 +120,8192,7168 +120,8192,2048 +120,2560,8192 +128,8192,7168 +128,8192,2048 +128,14336,8192 +128,2560,8192 +136,2560,8192 +136,14336,8192 +136,8192,2048 +136,8192,7168 +144,8192,2048 +144,8192,7168 +144,2560,8192 +144,14336,8192 +152,8192,7168 +152,8192,2048 +152,2560,8192 +152,14336,8192 +160,2560,8192 +160,14336,8192 +160,8192,7168 +160,8192,2048 +168,2560,8192 +168,14336,8192 +168,8192,7168 +168,8192,2048 +176,8192,7168 +176,2560,8192 +176,8192,2048 +176,14336,8192 +184,8192,7168 +184,8192,2048 +184,2560,8192 +184,14336,8192 +192,8192,7168 +192,8192,2048 +192,14336,8192 +192,2560,8192 +200,8192,7168 +200,2560,8192 +200,14336,8192 +200,8192,2048 +208,2560,8192 +208,14336,8192 +208,8192,2048 +208,8192,7168 +216,14336,8192 +216,8192,7168 +216,8192,2048 +216,2560,8192 +224,8192,7168 +224,2560,8192 +224,8192,2048 +224,14336,8192 +232,14336,8192 +232,8192,7168 +232,8192,2048 +232,2560,8192 +240,8192,7168 +240,2560,8192 +240,8192,2048 +240,14336,8192 +248,8192,7168 +248,2560,8192 +248,14336,8192 +248,8192,2048 +256,2560,8192 +256,14336,8192 +256,8192,7168 +256,8192,2048 +512,8192,7168 +512,8192,2048 +512,2560,8192 +512,14336,8192 +1024,14336,8192 +1024,8192,7168 +1024,8192,2048 +1024,2560,8192 +1536,8192,7168 +1536,2560,8192 +1536,8192,2048 +1536,14336,8192 +2048,2560,8192 +2048,14336,8192 +2048,8192,2048 +2048,8192,7168 +3072,2560,8192 +3072,14336,8192 +3072,8192,2048 +3072,8192,7168 +4096,8192,2048 +4096,8192,7168 +4096,2560,8192 +4096,14336,8192 +8192,8192,7168 +8192,8192,2048 +8192,14336,8192 +8192,2560,8192 +16384,8192,7168 +16384,2560,8192 +16384,8192,2048 +16384,14336,8192 +18432,8192,7168 +18432,2560,8192 +18432,8192,2048 +18432,14336,8192 +20480,14336,8192 +20480,8192,7168 +20480,8192,2048 +20480,2560,8192 +32768,14336,8192 +32768,2560,8192 +32768,8192,7168 +32768,8192,2048 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv new file mode 100644 index 00000000..5a4d4fc0 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-70B-Instruct-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,1280,8192 +1,7168,8192 +1,8192,3584 +1,8192,1024 +2,7168,8192 +2,1280,8192 +2,8192,3584 
+2,8192,1024 +4,7168,8192 +4,1280,8192 +4,8192,3584 +4,8192,1024 +8,1280,8192 +8,8192,3584 +8,8192,1024 +8,7168,8192 +16,7168,8192 +16,1280,8192 +16,8192,3584 +16,8192,1024 +17,8192,3584 +17,7168,8192 +17,1280,8192 +17,8192,1024 +24,7168,8192 +24,1280,8192 +24,8192,3584 +24,8192,1024 +32,8192,3584 +32,8192,1024 +32,1280,8192 +32,7168,8192 +40,8192,1024 +40,1280,8192 +40,7168,8192 +40,8192,3584 +48,1280,8192 +48,8192,3584 +48,8192,1024 +48,7168,8192 +56,8192,3584 +56,8192,1024 +56,7168,8192 +56,1280,8192 +64,1280,8192 +64,7168,8192 +64,8192,3584 +64,8192,1024 +72,8192,3584 +72,8192,1024 +72,7168,8192 +72,1280,8192 +80,7168,8192 +80,1280,8192 +80,8192,3584 +80,8192,1024 +88,7168,8192 +88,8192,3584 +88,8192,1024 +88,1280,8192 +96,8192,3584 +96,8192,1024 +96,1280,8192 +96,7168,8192 +104,8192,3584 +104,8192,1024 +104,7168,8192 +104,1280,8192 +112,7168,8192 +112,1280,8192 +112,8192,3584 +112,8192,1024 +120,7168,8192 +120,1280,8192 +120,8192,3584 +120,8192,1024 +128,1280,8192 +128,7168,8192 +128,8192,1024 +128,8192,3584 +136,7168,8192 +136,1280,8192 +136,8192,3584 +136,8192,1024 +144,8192,3584 +144,8192,1024 +144,7168,8192 +144,1280,8192 +152,8192,3584 +152,8192,1024 +152,7168,8192 +152,1280,8192 +160,1280,8192 +160,8192,3584 +160,7168,8192 +160,8192,1024 +168,8192,3584 +168,1280,8192 +168,7168,8192 +168,8192,1024 +176,7168,8192 +176,1280,8192 +176,8192,3584 +176,8192,1024 +184,7168,8192 +184,1280,8192 +184,8192,3584 +184,8192,1024 +192,8192,3584 +192,8192,1024 +192,1280,8192 +192,7168,8192 +200,8192,3584 +200,8192,1024 +200,7168,8192 +200,1280,8192 +208,7168,8192 +208,1280,8192 +208,8192,3584 +208,8192,1024 +216,7168,8192 +216,1280,8192 +216,8192,3584 +216,8192,1024 +224,1280,8192 +224,8192,3584 +224,8192,1024 +224,7168,8192 +232,8192,1024 +232,8192,3584 +232,1280,8192 +232,7168,8192 +240,7168,8192 +240,1280,8192 +240,8192,3584 +240,8192,1024 +248,7168,8192 +248,1280,8192 +248,8192,3584 +248,8192,1024 +256,8192,3584 +256,7168,8192 +256,1280,8192 +256,8192,1024 +512,8192,3584 +512,8192,1024 +512,7168,8192 +512,1280,8192 +1024,1280,8192 +1024,8192,3584 +1024,8192,1024 +1024,7168,8192 +1536,8192,3584 +1536,8192,1024 +1536,7168,8192 +1536,1280,8192 +2048,1280,8192 +2048,7168,8192 +2048,8192,3584 +2048,8192,1024 +3072,7168,8192 +3072,1280,8192 +3072,8192,3584 +3072,8192,1024 +4096,7168,8192 +4096,1280,8192 +4096,8192,3584 +4096,8192,1024 +8192,7168,8192 +8192,8192,1024 +8192,8192,3584 +8192,1280,8192 +16384,8192,3584 +16384,8192,1024 +16384,7168,8192 +16384,1280,8192 +18432,8192,3584 +18432,8192,1024 +18432,1280,8192 +18432,7168,8192 +20480,7168,8192 +20480,8192,1024 +20480,1280,8192 +20480,8192,3584 +32768,7168,8192 +32768,1280,8192 +32768,8192,3584 +32768,8192,1024 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv new file mode 100644 index 00000000..6039ac82 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,4096 +1,4096,4096 +1,4096,14336 +1,28672,4096 +2,4096,14336 +2,4096,4096 +2,6144,4096 +2,28672,4096 +4,4096,14336 +4,28672,4096 +4,6144,4096 +4,4096,4096 +8,6144,4096 +8,28672,4096 +8,4096,4096 +8,4096,14336 +16,6144,4096 +16,4096,4096 +16,4096,14336 +16,28672,4096 +17,28672,4096 +17,4096,14336 +17,4096,4096 +17,6144,4096 +24,4096,4096 +24,4096,14336 +24,6144,4096 +24,28672,4096 +32,4096,14336 +32,6144,4096 +32,4096,4096 +32,28672,4096 +40,28672,4096 +40,4096,14336 +40,6144,4096 +40,4096,4096 +48,28672,4096 +48,6144,4096 
+48,4096,4096 +48,4096,14336 +56,6144,4096 +56,28672,4096 +56,4096,4096 +56,4096,14336 +64,28672,4096 +64,6144,4096 +64,4096,14336 +64,4096,4096 +72,28672,4096 +72,6144,4096 +72,4096,4096 +72,4096,14336 +80,6144,4096 +80,4096,14336 +80,28672,4096 +80,4096,4096 +88,28672,4096 +88,6144,4096 +88,4096,4096 +88,4096,14336 +96,28672,4096 +96,6144,4096 +96,4096,4096 +96,4096,14336 +104,28672,4096 +104,6144,4096 +104,4096,4096 +104,4096,14336 +112,4096,4096 +112,6144,4096 +112,4096,14336 +112,28672,4096 +120,4096,14336 +120,28672,4096 +120,4096,4096 +120,6144,4096 +128,28672,4096 +128,6144,4096 +128,4096,14336 +128,4096,4096 +136,4096,14336 +136,28672,4096 +136,6144,4096 +136,4096,4096 +144,28672,4096 +144,6144,4096 +144,4096,4096 +144,4096,14336 +152,28672,4096 +152,6144,4096 +152,4096,4096 +152,4096,14336 +160,4096,4096 +160,6144,4096 +160,4096,14336 +160,28672,4096 +168,6144,4096 +168,28672,4096 +168,4096,4096 +168,4096,14336 +176,28672,4096 +176,4096,4096 +176,6144,4096 +176,4096,14336 +184,4096,4096 +184,6144,4096 +184,4096,14336 +184,28672,4096 +192,6144,4096 +192,28672,4096 +192,4096,14336 +192,4096,4096 +200,28672,4096 +200,6144,4096 +200,4096,4096 +200,4096,14336 +208,4096,4096 +208,4096,14336 +208,28672,4096 +208,6144,4096 +216,4096,14336 +216,4096,4096 +216,6144,4096 +216,28672,4096 +224,28672,4096 +224,6144,4096 +224,4096,4096 +224,4096,14336 +232,28672,4096 +232,4096,14336 +232,6144,4096 +232,4096,4096 +240,28672,4096 +240,6144,4096 +240,4096,4096 +240,4096,14336 +248,6144,4096 +248,4096,4096 +248,4096,14336 +248,28672,4096 +256,28672,4096 +256,4096,14336 +256,4096,4096 +256,6144,4096 +512,6144,4096 +512,28672,4096 +512,4096,14336 +512,4096,4096 +1024,4096,4096 +1024,6144,4096 +1024,28672,4096 +1024,4096,14336 +1536,6144,4096 +1536,28672,4096 +1536,4096,14336 +1536,4096,4096 +2048,6144,4096 +2048,28672,4096 +2048,4096,4096 +2048,4096,14336 +3072,4096,4096 +3072,4096,14336 +3072,6144,4096 +3072,28672,4096 +4096,4096,14336 +4096,28672,4096 +4096,6144,4096 +4096,4096,4096 +8192,28672,4096 +8192,4096,14336 +8192,6144,4096 +8192,4096,4096 +16384,4096,4096 +16384,4096,14336 +16384,28672,4096 +16384,6144,4096 +18432,4096,14336 +18432,6144,4096 +18432,4096,4096 +18432,28672,4096 +20480,4096,4096 +20480,4096,14336 +20480,28672,4096 +20480,6144,4096 +32768,28672,4096 +32768,6144,4096 +32768,4096,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv new file mode 100644 index 00000000..43859492 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,7168 +1,4096,2048 +1,3072,4096 +2,4096,7168 +2,4096,2048 +2,14336,4096 +2,3072,4096 +4,14336,4096 +4,3072,4096 +4,4096,7168 +4,4096,2048 +8,3072,4096 +8,14336,4096 +8,4096,7168 +8,4096,2048 +16,3072,4096 +16,14336,4096 +16,4096,7168 +16,4096,2048 +17,4096,2048 +17,14336,4096 +17,4096,7168 +17,3072,4096 +24,4096,7168 +24,4096,2048 +24,14336,4096 +24,3072,4096 +32,4096,2048 +32,14336,4096 +32,4096,7168 +32,3072,4096 +40,3072,4096 +40,4096,2048 +40,14336,4096 +40,4096,7168 +48,3072,4096 +48,4096,7168 +48,4096,2048 +48,14336,4096 +56,3072,4096 +56,4096,7168 +56,4096,2048 +56,14336,4096 +64,3072,4096 +64,4096,2048 +64,4096,7168 +64,14336,4096 +72,3072,4096 +72,4096,7168 +72,4096,2048 +72,14336,4096 +80,3072,4096 +80,4096,2048 +80,14336,4096 +80,4096,7168 +88,3072,4096 +88,4096,7168 +88,4096,2048 +88,14336,4096 +96,3072,4096 +96,4096,7168 +96,4096,2048 
+96,14336,4096 +104,3072,4096 +104,14336,4096 +104,4096,7168 +104,4096,2048 +112,3072,4096 +112,4096,7168 +112,4096,2048 +112,14336,4096 +120,3072,4096 +120,4096,7168 +120,4096,2048 +120,14336,4096 +128,14336,4096 +128,3072,4096 +128,4096,2048 +128,4096,7168 +136,4096,7168 +136,4096,2048 +136,14336,4096 +136,3072,4096 +144,3072,4096 +144,4096,7168 +144,4096,2048 +144,14336,4096 +152,3072,4096 +152,4096,7168 +152,4096,2048 +152,14336,4096 +160,4096,7168 +160,4096,2048 +160,14336,4096 +160,3072,4096 +168,3072,4096 +168,4096,7168 +168,4096,2048 +168,14336,4096 +176,4096,7168 +176,3072,4096 +176,4096,2048 +176,14336,4096 +184,3072,4096 +184,4096,7168 +184,4096,2048 +184,14336,4096 +192,4096,2048 +192,4096,7168 +192,14336,4096 +192,3072,4096 +200,4096,7168 +200,4096,2048 +200,14336,4096 +200,3072,4096 +208,4096,7168 +208,4096,2048 +208,14336,4096 +208,3072,4096 +216,4096,2048 +216,4096,7168 +216,14336,4096 +216,3072,4096 +224,3072,4096 +224,4096,7168 +224,4096,2048 +224,14336,4096 +232,3072,4096 +232,4096,7168 +232,4096,2048 +232,14336,4096 +240,3072,4096 +240,14336,4096 +240,4096,7168 +240,4096,2048 +248,3072,4096 +248,4096,7168 +248,4096,2048 +248,14336,4096 +256,4096,2048 +256,14336,4096 +256,4096,7168 +256,3072,4096 +512,14336,4096 +512,3072,4096 +512,4096,2048 +512,4096,7168 +1024,4096,7168 +1024,14336,4096 +1024,3072,4096 +1024,4096,2048 +1536,3072,4096 +1536,4096,2048 +1536,4096,7168 +1536,14336,4096 +2048,14336,4096 +2048,4096,7168 +2048,4096,2048 +2048,3072,4096 +3072,4096,7168 +3072,4096,2048 +3072,14336,4096 +3072,3072,4096 +4096,4096,2048 +4096,3072,4096 +4096,14336,4096 +4096,4096,7168 +8192,4096,2048 +8192,3072,4096 +8192,4096,7168 +8192,14336,4096 +16384,4096,7168 +16384,4096,2048 +16384,14336,4096 +16384,3072,4096 +18432,3072,4096 +18432,4096,2048 +18432,4096,7168 +18432,14336,4096 +20480,4096,7168 +20480,4096,2048 +20480,14336,4096 +20480,3072,4096 +32768,3072,4096 +32768,4096,7168 +32768,4096,2048 +32768,14336,4096 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv new file mode 100644 index 00000000..a44a194e --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,4096,3584 +1,4096,1024 +1,1536,4096 +1,7168,4096 +2,4096,3584 +2,7168,4096 +2,1536,4096 +2,4096,1024 +4,1536,4096 +4,4096,3584 +4,4096,1024 +4,7168,4096 +8,1536,4096 +8,4096,3584 +8,4096,1024 +8,7168,4096 +16,4096,3584 +16,4096,1024 +16,7168,4096 +16,1536,4096 +17,1536,4096 +17,4096,3584 +17,4096,1024 +17,7168,4096 +24,4096,1024 +24,4096,3584 +24,7168,4096 +24,1536,4096 +32,4096,3584 +32,4096,1024 +32,1536,4096 +32,7168,4096 +40,1536,4096 +40,7168,4096 +40,4096,3584 +40,4096,1024 +48,4096,3584 +48,7168,4096 +48,1536,4096 +48,4096,1024 +56,1536,4096 +56,4096,3584 +56,4096,1024 +56,7168,4096 +64,1536,4096 +64,4096,3584 +64,4096,1024 +64,7168,4096 +72,1536,4096 +72,4096,3584 +72,4096,1024 +72,7168,4096 +80,7168,4096 +80,1536,4096 +80,4096,3584 +80,4096,1024 +88,1536,4096 +88,7168,4096 +88,4096,1024 +88,4096,3584 +96,1536,4096 +96,4096,3584 +96,4096,1024 +96,7168,4096 +104,4096,3584 +104,4096,1024 +104,7168,4096 +104,1536,4096 +112,4096,1024 +112,7168,4096 +112,4096,3584 +112,1536,4096 +120,4096,3584 +120,7168,4096 +120,1536,4096 +120,4096,1024 +128,7168,4096 +128,1536,4096 +128,4096,3584 +128,4096,1024 +136,7168,4096 +136,1536,4096 +136,4096,3584 +136,4096,1024 +144,1536,4096 +144,4096,3584 +144,4096,1024 +144,7168,4096 +152,7168,4096 +152,4096,1024 +152,4096,3584 
+152,1536,4096 +160,4096,3584 +160,4096,1024 +160,1536,4096 +160,7168,4096 +168,4096,3584 +168,4096,1024 +168,7168,4096 +168,1536,4096 +176,1536,4096 +176,7168,4096 +176,4096,3584 +176,4096,1024 +184,4096,1024 +184,4096,3584 +184,1536,4096 +184,7168,4096 +192,1536,4096 +192,4096,3584 +192,4096,1024 +192,7168,4096 +200,4096,3584 +200,4096,1024 +200,7168,4096 +200,1536,4096 +208,4096,3584 +208,4096,1024 +208,7168,4096 +208,1536,4096 +216,4096,3584 +216,4096,1024 +216,7168,4096 +216,1536,4096 +224,1536,4096 +224,7168,4096 +224,4096,1024 +224,4096,3584 +232,1536,4096 +232,4096,3584 +232,4096,1024 +232,7168,4096 +240,1536,4096 +240,4096,3584 +240,4096,1024 +240,7168,4096 +248,4096,3584 +248,4096,1024 +248,1536,4096 +248,7168,4096 +256,1536,4096 +256,4096,3584 +256,4096,1024 +256,7168,4096 +512,1536,4096 +512,4096,3584 +512,4096,1024 +512,7168,4096 +1024,4096,3584 +1024,4096,1024 +1024,7168,4096 +1024,1536,4096 +1536,1536,4096 +1536,4096,3584 +1536,4096,1024 +1536,7168,4096 +2048,4096,3584 +2048,4096,1024 +2048,7168,4096 +2048,1536,4096 +3072,4096,3584 +3072,4096,1024 +3072,7168,4096 +3072,1536,4096 +4096,1536,4096 +4096,4096,3584 +4096,4096,1024 +4096,7168,4096 +8192,1536,4096 +8192,4096,3584 +8192,4096,1024 +8192,7168,4096 +16384,4096,1024 +16384,7168,4096 +16384,1536,4096 +16384,4096,3584 +18432,1536,4096 +18432,4096,3584 +18432,4096,1024 +18432,7168,4096 +20480,4096,3584 +20480,4096,1024 +20480,7168,4096 +20480,1536,4096 +32768,1536,4096 +32768,7168,4096 +32768,4096,3584 +32768,4096,1024 diff --git a/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv new file mode 100644 index 00000000..a0ea181a --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Llama-3.1-8B-Instruct-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,4096,512 +1,4096,1792 +1,3584,4096 +1,768,4096 +2,4096,1792 +2,3584,4096 +2,4096,512 +2,768,4096 +4,3584,4096 +4,768,4096 +4,4096,512 +4,4096,1792 +8,768,4096 +8,4096,512 +8,4096,1792 +8,3584,4096 +16,4096,512 +16,4096,1792 +16,3584,4096 +16,768,4096 +17,4096,1792 +17,3584,4096 +17,4096,512 +17,768,4096 +24,4096,1792 +24,3584,4096 +24,4096,512 +24,768,4096 +32,4096,1792 +32,4096,512 +32,768,4096 +32,3584,4096 +40,768,4096 +40,3584,4096 +40,4096,1792 +40,4096,512 +48,4096,512 +48,3584,4096 +48,768,4096 +48,4096,1792 +56,768,4096 +56,4096,512 +56,4096,1792 +56,3584,4096 +64,768,4096 +64,4096,1792 +64,4096,512 +64,3584,4096 +72,768,4096 +72,4096,512 +72,4096,1792 +72,3584,4096 +80,4096,1792 +80,768,4096 +80,3584,4096 +80,4096,512 +88,3584,4096 +88,768,4096 +88,4096,1792 +88,4096,512 +96,768,4096 +96,4096,512 +96,4096,1792 +96,3584,4096 +104,4096,512 +104,4096,1792 +104,768,4096 +104,3584,4096 +112,4096,1792 +112,4096,512 +112,768,4096 +112,3584,4096 +120,4096,512 +120,3584,4096 +120,768,4096 +120,4096,1792 +128,768,4096 +128,3584,4096 +128,4096,1792 +128,4096,512 +136,4096,1792 +136,3584,4096 +136,768,4096 +136,4096,512 +144,768,4096 +144,4096,512 +144,4096,1792 +144,3584,4096 +152,4096,1792 +152,4096,512 +152,3584,4096 +152,768,4096 +160,4096,1792 +160,4096,512 +160,3584,4096 +160,768,4096 +168,4096,512 +168,4096,1792 +168,3584,4096 +168,768,4096 +176,3584,4096 +176,768,4096 +176,4096,512 +176,4096,1792 +184,4096,1792 +184,4096,512 +184,3584,4096 +184,768,4096 +192,768,4096 +192,3584,4096 +192,4096,512 +192,4096,1792 +200,4096,512 +200,4096,1792 +200,3584,4096 +200,768,4096 +208,4096,1792 +208,3584,4096 +208,768,4096 +208,4096,512 +216,4096,1792 +216,3584,4096 +216,4096,512 +216,768,4096 +224,768,4096 
+224,4096,1792 +224,4096,512 +224,3584,4096 +232,3584,4096 +232,768,4096 +232,4096,512 +232,4096,1792 +240,3584,4096 +240,768,4096 +240,4096,512 +240,4096,1792 +248,4096,512 +248,4096,1792 +248,3584,4096 +248,768,4096 +256,4096,1792 +256,3584,4096 +256,4096,512 +256,768,4096 +512,768,4096 +512,4096,1792 +512,3584,4096 +512,4096,512 +1024,4096,512 +1024,768,4096 +1024,4096,1792 +1024,3584,4096 +1536,4096,1792 +1536,3584,4096 +1536,4096,512 +1536,768,4096 +2048,4096,512 +2048,4096,1792 +2048,3584,4096 +2048,768,4096 +3072,4096,512 +3072,4096,1792 +3072,3584,4096 +3072,768,4096 +4096,768,4096 +4096,3584,4096 +4096,4096,1792 +4096,4096,512 +8192,3584,4096 +8192,4096,1792 +8192,4096,512 +8192,768,4096 +16384,4096,1792 +16384,3584,4096 +16384,768,4096 +16384,4096,512 +18432,4096,1792 +18432,4096,512 +18432,3584,4096 +18432,768,4096 +20480,4096,1792 +20480,3584,4096 +20480,768,4096 +20480,4096,512 +32768,3584,4096 +32768,768,4096 +32768,4096,512 +32768,4096,1792 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv new file mode 100644 index 00000000..bf12b198 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,3072 +1,16384,6144 +1,4096,6144 +1,6144,16384 +2,4096,6144 +2,6144,3072 +2,6144,16384 +2,16384,6144 +4,4096,6144 +4,16384,6144 +4,6144,16384 +4,6144,3072 +8,6144,3072 +8,16384,6144 +8,6144,16384 +8,4096,6144 +16,16384,6144 +16,4096,6144 +16,6144,16384 +16,6144,3072 +17,6144,16384 +17,6144,3072 +17,16384,6144 +17,4096,6144 +24,6144,3072 +24,16384,6144 +24,4096,6144 +24,6144,16384 +32,6144,3072 +32,4096,6144 +32,6144,16384 +32,16384,6144 +40,6144,3072 +40,16384,6144 +40,4096,6144 +40,6144,16384 +48,6144,3072 +48,16384,6144 +48,6144,16384 +48,4096,6144 +56,6144,3072 +56,16384,6144 +56,6144,16384 +56,4096,6144 +64,16384,6144 +64,6144,16384 +64,4096,6144 +64,6144,3072 +72,16384,6144 +72,6144,3072 +72,6144,16384 +72,4096,6144 +80,6144,16384 +80,16384,6144 +80,4096,6144 +80,6144,3072 +88,16384,6144 +88,6144,16384 +88,6144,3072 +88,4096,6144 +96,16384,6144 +96,6144,3072 +96,6144,16384 +96,4096,6144 +104,16384,6144 +104,6144,3072 +104,6144,16384 +104,4096,6144 +112,6144,3072 +112,16384,6144 +112,6144,16384 +112,4096,6144 +120,6144,16384 +120,4096,6144 +120,6144,3072 +120,16384,6144 +128,16384,6144 +128,6144,16384 +128,4096,6144 +128,6144,3072 +136,16384,6144 +136,6144,3072 +136,4096,6144 +136,6144,16384 +144,16384,6144 +144,6144,3072 +144,6144,16384 +144,4096,6144 +152,6144,16384 +152,6144,3072 +152,16384,6144 +152,4096,6144 +160,16384,6144 +160,6144,16384 +160,4096,6144 +160,6144,3072 +168,16384,6144 +168,6144,16384 +168,4096,6144 +168,6144,3072 +176,6144,3072 +176,16384,6144 +176,6144,16384 +176,4096,6144 +184,6144,3072 +184,16384,6144 +184,6144,16384 +184,4096,6144 +192,6144,3072 +192,16384,6144 +192,6144,16384 +192,4096,6144 +200,6144,3072 +200,16384,6144 +200,4096,6144 +200,6144,16384 +208,6144,16384 +208,6144,3072 +208,16384,6144 +208,4096,6144 +216,16384,6144 +216,4096,6144 +216,6144,16384 +216,6144,3072 +224,6144,16384 +224,6144,3072 +224,16384,6144 +224,4096,6144 +232,16384,6144 +232,6144,3072 +232,4096,6144 +232,6144,16384 +240,6144,3072 +240,16384,6144 +240,4096,6144 +240,6144,16384 +248,6144,3072 +248,4096,6144 +248,6144,16384 +248,16384,6144 +256,6144,16384 +256,6144,3072 +256,16384,6144 +256,4096,6144 +512,6144,16384 +512,6144,3072 +512,16384,6144 +512,4096,6144 +1024,4096,6144 +1024,6144,16384 
+1024,6144,3072 +1024,16384,6144 +1536,16384,6144 +1536,6144,16384 +1536,6144,3072 +1536,4096,6144 +2048,16384,6144 +2048,6144,16384 +2048,4096,6144 +2048,6144,3072 +3072,4096,6144 +3072,6144,3072 +3072,16384,6144 +3072,6144,16384 +4096,16384,6144 +4096,6144,3072 +4096,6144,16384 +4096,4096,6144 +8192,6144,3072 +8192,16384,6144 +8192,4096,6144 +8192,6144,16384 +16384,6144,16384 +16384,6144,3072 +16384,16384,6144 +16384,4096,6144 +18432,6144,16384 +18432,6144,3072 +18432,16384,6144 +18432,4096,6144 +20480,6144,16384 +20480,6144,3072 +20480,16384,6144 +20480,4096,6144 +65536,6144,3072 +65536,16384,6144 +65536,4096,6144 +65536,6144,16384 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv new file mode 100644 index 00000000..05f238ce --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,16384,6144 +1,6144,1536 +1,6144,16384 +1,2048,6144 +2,2048,6144 +2,6144,1536 +2,6144,16384 +2,16384,6144 +4,2048,6144 +4,16384,6144 +4,6144,1536 +4,6144,16384 +8,16384,6144 +8,6144,1536 +8,6144,16384 +8,2048,6144 +16,16384,6144 +16,6144,1536 +16,6144,16384 +16,2048,6144 +17,6144,16384 +17,16384,6144 +17,2048,6144 +17,6144,1536 +24,16384,6144 +24,2048,6144 +24,6144,1536 +24,6144,16384 +32,6144,16384 +32,2048,6144 +32,6144,1536 +32,16384,6144 +40,2048,6144 +40,16384,6144 +40,6144,1536 +40,6144,16384 +48,2048,6144 +48,6144,1536 +48,16384,6144 +48,6144,16384 +56,16384,6144 +56,6144,1536 +56,6144,16384 +56,2048,6144 +64,6144,1536 +64,16384,6144 +64,6144,16384 +64,2048,6144 +72,16384,6144 +72,6144,1536 +72,6144,16384 +72,2048,6144 +80,6144,1536 +80,6144,16384 +80,2048,6144 +80,16384,6144 +88,2048,6144 +88,6144,1536 +88,16384,6144 +88,6144,16384 +96,6144,1536 +96,16384,6144 +96,6144,16384 +96,2048,6144 +104,6144,1536 +104,16384,6144 +104,6144,16384 +104,2048,6144 +112,6144,1536 +112,16384,6144 +112,6144,16384 +112,2048,6144 +120,6144,1536 +120,6144,16384 +120,2048,6144 +120,16384,6144 +128,2048,6144 +128,6144,1536 +128,16384,6144 +128,6144,16384 +136,2048,6144 +136,16384,6144 +136,6144,1536 +136,6144,16384 +144,16384,6144 +144,6144,1536 +144,6144,16384 +144,2048,6144 +152,2048,6144 +152,6144,1536 +152,6144,16384 +152,16384,6144 +160,6144,1536 +160,16384,6144 +160,6144,16384 +160,2048,6144 +168,16384,6144 +168,6144,1536 +168,6144,16384 +168,2048,6144 +176,2048,6144 +176,6144,1536 +176,16384,6144 +176,6144,16384 +184,6144,1536 +184,16384,6144 +184,6144,16384 +184,2048,6144 +192,16384,6144 +192,6144,16384 +192,2048,6144 +192,6144,1536 +200,16384,6144 +200,6144,1536 +200,6144,16384 +200,2048,6144 +208,6144,16384 +208,2048,6144 +208,16384,6144 +208,6144,1536 +216,16384,6144 +216,2048,6144 +216,6144,1536 +216,6144,16384 +224,2048,6144 +224,6144,1536 +224,6144,16384 +224,16384,6144 +232,16384,6144 +232,6144,1536 +232,6144,16384 +232,2048,6144 +240,16384,6144 +240,6144,1536 +240,6144,16384 +240,2048,6144 +248,6144,1536 +248,6144,16384 +248,16384,6144 +248,2048,6144 +256,6144,1536 +256,6144,16384 +256,16384,6144 +256,2048,6144 +512,6144,1536 +512,6144,16384 +512,16384,6144 +512,2048,6144 +1024,2048,6144 +1024,6144,1536 +1024,6144,16384 +1024,16384,6144 +1536,6144,1536 +1536,16384,6144 +1536,6144,16384 +1536,2048,6144 +2048,16384,6144 +2048,6144,1536 +2048,6144,16384 +2048,2048,6144 +3072,2048,6144 +3072,16384,6144 +3072,6144,1536 +3072,6144,16384 +4096,6144,1536 +4096,16384,6144 +4096,6144,16384 +4096,2048,6144 +8192,16384,6144 +8192,6144,1536 
+8192,6144,16384 +8192,2048,6144 +16384,6144,16384 +16384,2048,6144 +16384,16384,6144 +16384,6144,1536 +18432,6144,1536 +18432,6144,16384 +18432,16384,6144 +18432,2048,6144 +20480,6144,16384 +20480,2048,6144 +20480,16384,6144 +20480,6144,1536 +65536,16384,6144 +65536,6144,1536 +65536,6144,16384 +65536,2048,6144 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv new file mode 100644 index 00000000..195089f7 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x22B-Instruct-v0.1-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,6144,768 +1,1024,6144 +1,16384,6144 +1,6144,16384 +2,1024,6144 +2,6144,768 +2,6144,16384 +2,16384,6144 +4,16384,6144 +4,6144,768 +4,1024,6144 +4,6144,16384 +8,6144,768 +8,1024,6144 +8,16384,6144 +8,6144,16384 +16,6144,768 +16,1024,6144 +16,16384,6144 +16,6144,16384 +17,6144,16384 +17,6144,768 +17,1024,6144 +17,16384,6144 +24,16384,6144 +24,1024,6144 +24,6144,768 +24,6144,16384 +32,6144,768 +32,1024,6144 +32,6144,16384 +32,16384,6144 +40,6144,768 +40,1024,6144 +40,16384,6144 +40,6144,16384 +48,6144,768 +48,1024,6144 +48,16384,6144 +48,6144,16384 +56,6144,768 +56,1024,6144 +56,16384,6144 +56,6144,16384 +64,1024,6144 +64,16384,6144 +64,6144,16384 +64,6144,768 +72,1024,6144 +72,16384,6144 +72,6144,768 +72,6144,16384 +80,6144,768 +80,6144,16384 +80,1024,6144 +80,16384,6144 +88,6144,768 +88,16384,6144 +88,6144,16384 +88,1024,6144 +96,1024,6144 +96,16384,6144 +96,6144,768 +96,6144,16384 +104,1024,6144 +104,16384,6144 +104,6144,768 +104,6144,16384 +112,6144,768 +112,1024,6144 +112,16384,6144 +112,6144,16384 +120,1024,6144 +120,6144,16384 +120,16384,6144 +120,6144,768 +128,6144,768 +128,1024,6144 +128,16384,6144 +128,6144,16384 +136,16384,6144 +136,6144,768 +136,1024,6144 +136,6144,16384 +144,1024,6144 +144,16384,6144 +144,6144,768 +144,6144,16384 +152,6144,768 +152,6144,16384 +152,1024,6144 +152,16384,6144 +160,6144,768 +160,1024,6144 +160,16384,6144 +160,6144,16384 +168,6144,768 +168,1024,6144 +168,16384,6144 +168,6144,16384 +176,1024,6144 +176,6144,768 +176,16384,6144 +176,6144,16384 +184,6144,768 +184,1024,6144 +184,16384,6144 +184,6144,16384 +192,6144,768 +192,1024,6144 +192,16384,6144 +192,6144,16384 +200,6144,768 +200,1024,6144 +200,16384,6144 +200,6144,16384 +208,6144,16384 +208,6144,768 +208,1024,6144 +208,16384,6144 +216,16384,6144 +216,6144,768 +216,1024,6144 +216,6144,16384 +224,6144,768 +224,6144,16384 +224,1024,6144 +224,16384,6144 +232,16384,6144 +232,6144,768 +232,1024,6144 +232,6144,16384 +240,6144,768 +240,1024,6144 +240,16384,6144 +240,6144,16384 +248,6144,768 +248,1024,6144 +248,6144,16384 +248,16384,6144 +256,6144,16384 +256,1024,6144 +256,16384,6144 +256,6144,768 +512,6144,768 +512,6144,16384 +512,1024,6144 +512,16384,6144 +1024,6144,768 +1024,6144,16384 +1024,1024,6144 +1024,16384,6144 +1536,6144,768 +1536,1024,6144 +1536,16384,6144 +1536,6144,16384 +2048,1024,6144 +2048,16384,6144 +2048,6144,16384 +2048,6144,768 +3072,6144,768 +3072,1024,6144 +3072,16384,6144 +3072,6144,16384 +4096,1024,6144 +4096,16384,6144 +4096,6144,768 +4096,6144,16384 +8192,6144,768 +8192,1024,6144 +8192,16384,6144 +8192,6144,16384 +16384,6144,16384 +16384,6144,768 +16384,1024,6144 +16384,16384,6144 +18432,6144,16384 +18432,6144,768 +18432,1024,6144 +18432,16384,6144 +20480,6144,16384 +20480,6144,768 +20480,1024,6144 +20480,16384,6144 +65536,6144,768 +65536,1024,6144 +65536,16384,6144 +65536,6144,16384 diff --git 
a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv new file mode 100644 index 00000000..7ff7b198 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,6144,4096 +1,4096,4096 +1,4096,14336 +2,14336,4096 +2,4096,14336 +2,4096,4096 +2,6144,4096 +4,14336,4096 +4,4096,14336 +4,6144,4096 +4,4096,4096 +8,14336,4096 +8,6144,4096 +8,4096,4096 +8,4096,14336 +16,14336,4096 +16,6144,4096 +16,4096,4096 +16,4096,14336 +17,14336,4096 +17,4096,14336 +17,4096,4096 +17,6144,4096 +24,4096,4096 +24,14336,4096 +24,4096,14336 +24,6144,4096 +32,14336,4096 +32,4096,14336 +32,6144,4096 +32,4096,4096 +40,14336,4096 +40,4096,14336 +40,6144,4096 +40,4096,4096 +48,6144,4096 +48,4096,4096 +48,14336,4096 +48,4096,14336 +56,6144,4096 +56,4096,4096 +56,14336,4096 +56,4096,14336 +64,6144,4096 +64,4096,14336 +64,4096,4096 +64,14336,4096 +72,6144,4096 +72,4096,4096 +72,14336,4096 +72,4096,14336 +80,14336,4096 +80,6144,4096 +80,4096,14336 +80,4096,4096 +88,6144,4096 +88,4096,4096 +88,14336,4096 +88,4096,14336 +96,6144,4096 +96,4096,4096 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,6144,4096 +104,4096,4096 +104,4096,14336 +112,4096,4096 +112,14336,4096 +112,6144,4096 +112,4096,14336 +120,4096,14336 +120,4096,4096 +120,14336,4096 +120,6144,4096 +128,14336,4096 +128,6144,4096 +128,4096,14336 +128,4096,4096 +136,14336,4096 +136,4096,14336 +136,6144,4096 +136,4096,4096 +144,6144,4096 +144,4096,4096 +144,14336,4096 +144,4096,14336 +152,6144,4096 +152,4096,4096 +152,14336,4096 +152,4096,14336 +160,4096,4096 +160,6144,4096 +160,4096,14336 +160,14336,4096 +168,6144,4096 +168,4096,4096 +168,14336,4096 +168,4096,14336 +176,4096,4096 +176,14336,4096 +176,6144,4096 +176,4096,14336 +184,4096,4096 +184,14336,4096 +184,6144,4096 +184,4096,14336 +192,6144,4096 +192,4096,14336 +192,4096,4096 +192,14336,4096 +200,6144,4096 +200,4096,4096 +200,14336,4096 +200,4096,14336 +208,4096,4096 +208,14336,4096 +208,4096,14336 +208,6144,4096 +216,4096,14336 +216,4096,4096 +216,14336,4096 +216,6144,4096 +224,6144,4096 +224,4096,4096 +224,14336,4096 +224,4096,14336 +232,4096,14336 +232,6144,4096 +232,4096,4096 +232,14336,4096 +240,14336,4096 +240,6144,4096 +240,4096,4096 +240,4096,14336 +248,6144,4096 +248,4096,4096 +248,14336,4096 +248,4096,14336 +256,14336,4096 +256,4096,14336 +256,4096,4096 +256,6144,4096 +512,14336,4096 +512,6144,4096 +512,4096,14336 +512,4096,4096 +1024,4096,4096 +1024,14336,4096 +1024,6144,4096 +1024,4096,14336 +1536,6144,4096 +1536,4096,14336 +1536,4096,4096 +1536,14336,4096 +2048,6144,4096 +2048,14336,4096 +2048,4096,4096 +2048,4096,14336 +3072,4096,4096 +3072,14336,4096 +3072,4096,14336 +3072,6144,4096 +4096,4096,14336 +4096,14336,4096 +4096,6144,4096 +4096,4096,4096 +8192,4096,14336 +8192,6144,4096 +8192,4096,4096 +8192,14336,4096 +16384,4096,4096 +16384,14336,4096 +16384,4096,14336 +16384,6144,4096 +18432,4096,14336 +18432,6144,4096 +18432,4096,4096 +18432,14336,4096 +20480,4096,4096 +20480,14336,4096 +20480,4096,14336 +20480,6144,4096 +32768,6144,4096 +32768,4096,4096 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv new file mode 100644 index 00000000..7c7256c8 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP2.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,2048 
+1,4096,14336 +1,3072,4096 +2,4096,2048 +2,14336,4096 +2,4096,14336 +2,3072,4096 +4,14336,4096 +4,4096,14336 +4,3072,4096 +4,4096,2048 +8,3072,4096 +8,14336,4096 +8,4096,2048 +8,4096,14336 +16,3072,4096 +16,14336,4096 +16,4096,2048 +16,4096,14336 +17,4096,2048 +17,14336,4096 +17,4096,14336 +17,3072,4096 +24,4096,2048 +24,14336,4096 +24,4096,14336 +24,3072,4096 +32,4096,2048 +32,14336,4096 +32,4096,14336 +32,3072,4096 +40,3072,4096 +40,4096,2048 +40,14336,4096 +40,4096,14336 +48,3072,4096 +48,4096,2048 +48,14336,4096 +48,4096,14336 +56,3072,4096 +56,4096,2048 +56,14336,4096 +56,4096,14336 +64,3072,4096 +64,4096,2048 +64,4096,14336 +64,14336,4096 +72,3072,4096 +72,4096,2048 +72,14336,4096 +72,4096,14336 +80,3072,4096 +80,4096,2048 +80,14336,4096 +80,4096,14336 +88,3072,4096 +88,4096,2048 +88,14336,4096 +88,4096,14336 +96,3072,4096 +96,4096,2048 +96,14336,4096 +96,4096,14336 +104,3072,4096 +104,14336,4096 +104,4096,2048 +104,4096,14336 +112,3072,4096 +112,4096,2048 +112,14336,4096 +112,4096,14336 +120,4096,14336 +120,3072,4096 +120,4096,2048 +120,14336,4096 +128,14336,4096 +128,3072,4096 +128,4096,2048 +128,4096,14336 +136,4096,2048 +136,14336,4096 +136,4096,14336 +136,3072,4096 +144,3072,4096 +144,4096,2048 +144,14336,4096 +144,4096,14336 +152,3072,4096 +152,4096,2048 +152,14336,4096 +152,4096,14336 +160,4096,2048 +160,4096,14336 +160,14336,4096 +160,3072,4096 +168,3072,4096 +168,4096,2048 +168,14336,4096 +168,4096,14336 +176,3072,4096 +176,4096,2048 +176,14336,4096 +176,4096,14336 +184,3072,4096 +184,4096,2048 +184,14336,4096 +184,4096,14336 +192,4096,2048 +192,4096,14336 +192,14336,4096 +192,3072,4096 +200,4096,2048 +200,14336,4096 +200,4096,14336 +200,3072,4096 +208,4096,2048 +208,14336,4096 +208,4096,14336 +208,3072,4096 +216,4096,2048 +216,4096,14336 +216,14336,4096 +216,3072,4096 +224,3072,4096 +224,4096,2048 +224,14336,4096 +224,4096,14336 +232,3072,4096 +232,4096,14336 +232,4096,2048 +232,14336,4096 +240,3072,4096 +240,14336,4096 +240,4096,2048 +240,4096,14336 +248,3072,4096 +248,4096,2048 +248,14336,4096 +248,4096,14336 +256,4096,2048 +256,14336,4096 +256,4096,14336 +256,3072,4096 +512,14336,4096 +512,3072,4096 +512,4096,2048 +512,4096,14336 +1024,14336,4096 +1024,3072,4096 +1024,4096,2048 +1024,4096,14336 +1536,3072,4096 +1536,4096,2048 +1536,4096,14336 +1536,14336,4096 +2048,14336,4096 +2048,4096,2048 +2048,4096,14336 +2048,3072,4096 +3072,4096,2048 +3072,14336,4096 +3072,4096,14336 +3072,3072,4096 +4096,4096,2048 +4096,4096,14336 +4096,3072,4096 +4096,14336,4096 +8192,4096,2048 +8192,3072,4096 +8192,4096,14336 +8192,14336,4096 +16384,4096,2048 +16384,14336,4096 +16384,4096,14336 +16384,3072,4096 +18432,3072,4096 +18432,4096,2048 +18432,4096,14336 +18432,14336,4096 +20480,4096,2048 +20480,14336,4096 +20480,4096,14336 +20480,3072,4096 +32768,3072,4096 +32768,4096,2048 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv new file mode 100644 index 00000000..da2e030c --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP4.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,1024 +1,4096,14336 +1,1536,4096 +2,14336,4096 +2,4096,14336 +2,1536,4096 +2,4096,1024 +4,14336,4096 +4,4096,14336 +4,1536,4096 +4,4096,1024 +8,1536,4096 +8,14336,4096 +8,4096,1024 +8,4096,14336 +16,14336,4096 +16,4096,1024 +16,4096,14336 +16,1536,4096 +17,1536,4096 +17,14336,4096 +17,4096,14336 +17,4096,1024 +24,4096,1024 +24,14336,4096 
+24,4096,14336 +24,1536,4096 +32,14336,4096 +32,4096,14336 +32,4096,1024 +32,1536,4096 +40,1536,4096 +40,14336,4096 +40,4096,14336 +40,4096,1024 +48,1536,4096 +48,4096,1024 +48,14336,4096 +48,4096,14336 +56,1536,4096 +56,4096,1024 +56,14336,4096 +56,4096,14336 +64,1536,4096 +64,4096,14336 +64,4096,1024 +64,14336,4096 +72,1536,4096 +72,4096,1024 +72,14336,4096 +72,4096,14336 +80,14336,4096 +80,4096,14336 +80,1536,4096 +80,4096,1024 +88,1536,4096 +88,4096,1024 +88,14336,4096 +88,4096,14336 +96,1536,4096 +96,4096,1024 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,4096,1024 +104,4096,14336 +104,1536,4096 +112,4096,1024 +112,14336,4096 +112,4096,14336 +112,1536,4096 +120,4096,14336 +120,1536,4096 +120,4096,1024 +120,14336,4096 +128,14336,4096 +128,1536,4096 +128,4096,14336 +128,4096,1024 +136,14336,4096 +136,4096,14336 +136,1536,4096 +136,4096,1024 +144,1536,4096 +144,4096,1024 +144,14336,4096 +144,4096,14336 +152,4096,1024 +152,14336,4096 +152,4096,14336 +152,1536,4096 +160,4096,14336 +160,4096,1024 +160,14336,4096 +160,1536,4096 +168,4096,1024 +168,14336,4096 +168,4096,14336 +168,1536,4096 +176,1536,4096 +176,4096,1024 +176,14336,4096 +176,4096,14336 +184,4096,1024 +184,14336,4096 +184,4096,14336 +184,1536,4096 +192,1536,4096 +192,4096,14336 +192,4096,1024 +192,14336,4096 +200,4096,1024 +200,14336,4096 +200,4096,14336 +200,1536,4096 +208,4096,1024 +208,14336,4096 +208,4096,14336 +208,1536,4096 +216,4096,14336 +216,4096,1024 +216,14336,4096 +216,1536,4096 +224,1536,4096 +224,4096,1024 +224,14336,4096 +224,4096,14336 +232,1536,4096 +232,4096,14336 +232,4096,1024 +232,14336,4096 +240,1536,4096 +240,14336,4096 +240,4096,1024 +240,4096,14336 +248,4096,1024 +248,14336,4096 +248,4096,14336 +248,1536,4096 +256,1536,4096 +256,14336,4096 +256,4096,14336 +256,4096,1024 +512,14336,4096 +512,1536,4096 +512,4096,14336 +512,4096,1024 +1024,4096,1024 +1024,14336,4096 +1024,1536,4096 +1024,4096,14336 +1536,1536,4096 +1536,4096,14336 +1536,4096,1024 +1536,14336,4096 +2048,14336,4096 +2048,4096,1024 +2048,4096,14336 +2048,1536,4096 +3072,4096,1024 +3072,14336,4096 +3072,4096,14336 +3072,1536,4096 +4096,4096,14336 +4096,1536,4096 +4096,14336,4096 +4096,4096,1024 +8192,1536,4096 +8192,4096,14336 +8192,4096,1024 +8192,14336,4096 +16384,4096,1024 +16384,14336,4096 +16384,4096,14336 +16384,1536,4096 +18432,1536,4096 +18432,4096,14336 +18432,4096,1024 +18432,14336,4096 +20480,4096,1024 +20480,14336,4096 +20480,4096,14336 +20480,1536,4096 +32768,1536,4096 +32768,4096,1024 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv new file mode 100644 index 00000000..3a337241 --- /dev/null +++ b/aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP8.csv @@ -0,0 +1,189 @@ +M,N,K +1,14336,4096 +1,4096,512 +1,4096,14336 +1,768,4096 +2,14336,4096 +2,4096,14336 +2,4096,512 +2,768,4096 +4,14336,4096 +4,4096,14336 +4,768,4096 +4,4096,512 +8,768,4096 +8,14336,4096 +8,4096,512 +8,4096,14336 +16,14336,4096 +16,4096,512 +16,4096,14336 +16,768,4096 +17,14336,4096 +17,4096,14336 +17,4096,512 +17,768,4096 +24,14336,4096 +24,4096,14336 +24,4096,512 +24,768,4096 +32,14336,4096 +32,4096,14336 +32,4096,512 +32,768,4096 +40,768,4096 +40,14336,4096 +40,4096,14336 +40,4096,512 +48,4096,512 +48,768,4096 +48,14336,4096 +48,4096,14336 +56,768,4096 +56,4096,512 +56,14336,4096 +56,4096,14336 +64,768,4096 +64,4096,14336 +64,4096,512 +64,14336,4096 +72,768,4096 +72,4096,512 +72,14336,4096 
+72,4096,14336 +80,14336,4096 +80,4096,14336 +80,768,4096 +80,4096,512 +88,768,4096 +88,14336,4096 +88,4096,14336 +88,4096,512 +96,768,4096 +96,4096,512 +96,14336,4096 +96,4096,14336 +104,14336,4096 +104,4096,512 +104,4096,14336 +104,768,4096 +112,14336,4096 +112,4096,14336 +112,4096,512 +112,768,4096 +120,4096,14336 +120,4096,512 +120,768,4096 +120,14336,4096 +128,14336,4096 +128,768,4096 +128,4096,14336 +128,4096,512 +136,14336,4096 +136,4096,14336 +136,768,4096 +136,4096,512 +144,768,4096 +144,4096,512 +144,14336,4096 +144,4096,14336 +152,14336,4096 +152,4096,14336 +152,4096,512 +152,768,4096 +160,4096,14336 +160,4096,512 +160,14336,4096 +160,768,4096 +168,4096,512 +168,14336,4096 +168,4096,14336 +168,768,4096 +176,768,4096 +176,4096,512 +176,14336,4096 +176,4096,14336 +184,14336,4096 +184,4096,14336 +184,4096,512 +184,768,4096 +192,768,4096 +192,4096,14336 +192,4096,512 +192,14336,4096 +200,4096,512 +200,14336,4096 +200,4096,14336 +200,768,4096 +208,14336,4096 +208,4096,14336 +208,768,4096 +208,4096,512 +216,4096,14336 +216,4096,512 +216,14336,4096 +216,768,4096 +224,768,4096 +224,14336,4096 +224,4096,14336 +224,4096,512 +232,768,4096 +232,4096,14336 +232,4096,512 +232,14336,4096 +240,768,4096 +240,14336,4096 +240,4096,512 +240,4096,14336 +248,4096,512 +248,14336,4096 +248,4096,14336 +248,768,4096 +256,14336,4096 +256,4096,14336 +256,4096,512 +256,768,4096 +512,14336,4096 +512,768,4096 +512,4096,14336 +512,4096,512 +1024,4096,512 +1024,14336,4096 +1024,768,4096 +1024,4096,14336 +1536,4096,14336 +1536,4096,512 +1536,14336,4096 +1536,768,4096 +2048,14336,4096 +2048,4096,512 +2048,4096,14336 +2048,768,4096 +3072,4096,512 +3072,14336,4096 +3072,4096,14336 +3072,768,4096 +4096,4096,14336 +4096,768,4096 +4096,14336,4096 +4096,4096,512 +8192,4096,14336 +8192,4096,512 +8192,14336,4096 +8192,768,4096 +16384,14336,4096 +16384,4096,14336 +16384,768,4096 +16384,4096,512 +18432,4096,14336 +18432,4096,512 +18432,14336,4096 +18432,768,4096 +20480,14336,4096 +20480,4096,14336 +20480,768,4096 +20480,4096,512 +32768,768,4096 +32768,4096,512 +32768,14336,4096 +32768,4096,14336 diff --git a/aiter/test_common.py b/aiter/test_common.py index b78a8dc6..266a01fc 100644 --- a/aiter/test_common.py +++ b/aiter/test_common.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -import torch -import torch.profiler as tpf -import os +from aiter import logger import numpy as np +import os import pandas as pd -from aiter import logger +import torch +import torch.profiler as tpf +from typing import TypeVar + + +T = TypeVar("T") + +MAX_RAND_INT = 20 +MIN_RAND_INT = -20 def perftest(num_iters=101, num_warmup=10, testGraph=False): @@ -155,3 +161,83 @@ def tensor_load(filename: str): metafile = '.'.join(filename.split('.')[:-1])+'.meta' shape, dtype = [eval(line.strip()) for line in open(metafile)] return torch.tensor(DWs).view(dtype).view(shape) + + +def rand_tensor( + size: torch.Size, + dtype: torch.dtype +) -> torch.tensor: + """ + Generate a random PyTorch tensor with specified shape and data type. + + - For integer types: Uses torch.randint to generate random integers within a fixed range. + - For float types: Uses torch.rand to generate random floats between 0 and 1. + + Parameters + ---------- + size: torch.Size + The size of the generated tensor. + dtype: torch.dtype + The data type of the generated tensor. + + Returns + ------- + torch.Tensor : A random tensor of the specified shape and data type. 
+ + Raises + ------ + ValueError + If an unsupported data type is provided. + """ + if dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + # For integer types, use randint + return torch.randint(MIN_RAND_INT, MAX_RAND_INT, size=size, dtype=dtype) + elif dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]: + # For float types, use rand + return torch.rand(size=size, dtype=dtype) + elif dtype == torch.float8_e4m3fnuz: + # Special case for float8_e4m3fnuz + return torch.rand(size=size, dtype=torch.float16).to(torch.float8_e4m3fnuz) + + raise ValueError(f"Unsupported dtype: {dtype}") + + +def assert_all_close( + a: torch.tensor, + b: torch.tensor, + rtol: float, + atol: float, +) -> None: + """ + Check if all elements in two tensors are close within specified tolerances. + + Parameters + ---------- + a: torch.tensor + First input tensor. + b: torch.tensor + Second input tensor. + rtol: float + Relative tolerance. + atol: float + Absolute tolerance. + + Raises + ------- + AssertionError + If any elements in 'a' and 'b' are not close within the specified tolerances. + The error message includes details about the maximum and average delta, + and the percentage of elements that are not close. + """ + is_close = torch.isclose(a, b, rtol=rtol, atol=atol) + is_not_close = ~is_close + num_not_close = is_not_close.sum() + delta = (a-b)[is_not_close] + percent = num_not_close/a.numel() + message = "" if num_not_close == 0 else f""" +assert_all_close failed! +max delta:{delta.max()} +average delta:{delta.mean()} +delta details: {percent:.1%} ({num_not_close} of {a.numel()}) elements + """ + assert is_close.all(), message diff --git a/csrc/ck_gemm_a8w8/README.md b/csrc/ck_gemm_a8w8/README.md index 884a7d94..e220a8a3 100644 --- a/csrc/ck_gemm_a8w8/README.md +++ b/csrc/ck_gemm_a8w8/README.md @@ -15,4 +15,39 @@ You can find the results of the tuning in `aiter/configs/a8w8_tuned_gemm.csv`. ## More If you want to re-install gemm_a8w8, you should remove `aiter/jit/module_gemm_a8w8.so` and `aiter/jit/build/module_gemm_a8w8` first. -If you use flag `PREBUILD_KERNELS=1 USE_CK_A8W8=1` when you install aiter, it will build gemm a8w8 kernels in `aiter/configs/a8w8_tuned_gemm.csv` by default. If you want to use the new result of gemm_a8w8_tune, please remove `build` and `*.so` first, then re-intall aiter after finishing tune. \ No newline at end of file +If you use flag `PREBUILD_KERNELS=1 USE_CK_A8W8=1` when you install aiter, it will build gemm a8w8 kernels in `aiter/configs/a8w8_tuned_gemm.csv` by default. If you want to use the new result of gemm_a8w8_tune, please remove `build` and `*.so` first, then re-install aiter after the tuning is finished. + + +## FP8 A8W8 Rowwise Scaling GEMM + +The following steps will walk you through the full process of getting the best performance out of your hardware. + +0. Clear gemm_a8w8: Remove `aiter/jit/module_gemm_a8w8.so` and `aiter/jit/build/module_gemm_a8w8`. +```bash +rm -rf aiter/jit/module_gemm_a8w8.so +rm -rf aiter/jit/build/module_gemm_a8w8 +``` + +1. Install the AITER library +```bash +python3 setup.py develop +``` + +2. Tune your GEMM kernel +```bash +python3 csrc/ck_gemm_a8w8/gemm_a8w8_tune.py -i aiter/configs/a8w8_gemm_model_config/Mixtral-8x7B-Instruct-v0.1-TP1.csv -o aiter/configs/a8w8_tuned_gemm.csv --dtype fp8 -k +``` + +3. 
Use the operator: +```python +from aiter.ops.gemm_op_a8w8 import gemm_a8w8_CK + +output = gemm_a8w8_CK( + qinput, # [M, K] + weight, # [N, K] + x_scale, # [M, 1] + weight_scale, # [1, N] + bias, + dtype=out_dtype # torch.bfloat16, torch.float16 +) +``` diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8.cu b/csrc/ck_gemm_a8w8/gemm_a8w8.cu index 00fcf0c3..e48b90a8 100644 --- a/csrc/ck_gemm_a8w8/gemm_a8w8.cu +++ b/csrc/ck_gemm_a8w8/gemm_a8w8.cu @@ -31,34 +31,82 @@ using RowwiseKernelMap = std::unordered_map< RowwiseKernel, IntTupleHash>; -template +template < + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename DDataType, + typename EDataType +> RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) { // Apply shape heuristics to find a suitable kernel implementation. if (M < 64 && N < 2048 && K < 2048) { // Kernel that generally works well on small shapes. - return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2; + return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M < 64 && K < 2048) { // Kernel that works well for small batch size and small K. - return a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; + return a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M < 64 && N < 2048) { // Kernel that works well for small batch size and small N. - return a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2; + return a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M < 64 && N > 2048 && K > 2048) { // Kernel that works well for small M but larger N and K. - return a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1; + return a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M < 64) { // Fallback to generic small batch kernel if we cant find a good match. - return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2; + return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; /* } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) && K >= 1024) { // Kernel that is optimized for larger batch sizes but otherwise small // tensors. @@ -67,23 +115,54 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) else if (K < 1024) { // Special case for small K. - return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M < 1024) { // Kernel for generic medium batch sizes. 
- return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else if (M >= 1024 && N >= 1024 && K >= 1024) { // Kernel for very large gemm // return a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; - return a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >; } else { // Fallback large kernel. - return a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; + return a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; } } @@ -95,23 +174,51 @@ static constexpr int nextPow2(unsigned int num) return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -template +template < + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename DDataType, + typename EDataType +> RowwiseKernel rowwise_dispatch(int M, int N, int K) { // For a given shape, either find the best kernel via lookup or heuristic. // For many small M shapes, we bucket them to the next largest kernel. // This is fine since kernels are padded anyway. - static const auto lookup = [] { + if constexpr (std::is_same_v) { - return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)}; + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE( + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + F16 + ) + }; } else if constexpr (std::is_same_v) { - return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)}; - } else { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE( + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + B16 + ) + }; + } + else { static_assert(false, "rowwise_dispatch used with unsupported dtype!"); - } }(); - + } + }(); + // First check if this shape(M,N,K) is available in the direct lookup. auto it = lookup.find({M, N, K}); // If we found an optimal kernel, use it. @@ -141,7 +248,15 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K) return it->second; } // Otherwise, use heuristics. 
- return rowwise_heuristic_dispatch(M, N, K); + return rowwise_heuristic_dispatch< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType + >(M, N, K); } torch::Tensor gemm_a8w8( @@ -153,10 +268,10 @@ torch::Tensor gemm_a8w8( std::optional bias, int splitK) { - TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(), - "Weights and activations should both be int8!"); - TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), - "Scales should have the same dtype!"); + const auto is_int8 = XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(); + const auto is_fp8 = XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype(); + TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be int8 or fp8!"); + TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); if (bias != std::nullopt) TORCH_CHECK(bias.value().dtype() == Y.dtype(), "Out amd bias should have the same dtype!"); @@ -165,26 +280,50 @@ torch::Tensor gemm_a8w8( int N = WQ.size(0); int K = XQ.size(1); int KBatch = std::pow(2, splitK); - - if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) - { - rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - } - else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) - { - rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - } - else if (Y.dtype() == at::ScalarType::Half) - { - rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - } - else if (Y.dtype() == at::ScalarType::BFloat16) - { - rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - } - else - { - TORCH_CHECK(false, "Unsupported scales/output dtype!"); + // TODO: simplify + if(is_int8){ + if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + } else { + if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else if (Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } } + return Y; } diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_common.py b/csrc/ck_gemm_a8w8/gemm_a8w8_common.py index 279e21fd..4c484caf 100644 --- a/csrc/ck_gemm_a8w8/gemm_a8w8_common.py +++ b/csrc/ck_gemm_a8w8/gemm_a8w8_common.py @@ -3,7 +3,7 @@ from dataclasses import dataclass @dataclass -class kernelInstance: +class KernelParameters: BLOCK_SIZE: int MPerBLOCK: int 
NPerBLOCK: int @@ -38,115 +38,116 @@ def name(self) -> str: ("x").join(map(lambda x: str(x), [ self.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, self.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE])), self.LOOP_SCHED.lower(), - f"v{self.PIPELINE_VERSION}" + f"v{self.PIPELINE_VERSION}", + ]) -kernels_list = { +kernels_params_dict = { # id: kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_MAP_M| WAVE_MAP_N| ABLOCK_TRANSFER| BBLOCK_TRANSFER| CBLOCK_TRANSFER| CBLOCK_SPV| CSHUFFLE_MX| CSHUFFLE_NX| LOOP_SCHED| PIPELINE_VERSION - 0: kernelInstance( 256, 256, 256, 64, 32, 32, 4, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 1: kernelInstance( 256, 256, 256, 128, 32, 32, 4, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 2: kernelInstance( 256, 256, 224, 128, 32, 32, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 3: kernelInstance( 256, 256, 192, 128, 32, 32, 4, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 4: kernelInstance( 256, 256, 160, 128, 32, 32, 2, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 5: kernelInstance( 256, 256, 128, 128, 32, 32, 4, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 6: kernelInstance( 256, 256, 96, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 7: kernelInstance( 256, 256, 64, 128, 32, 32, 4, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 8: kernelInstance( 256, 128, 256, 128, 32, 32, 2, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 9: kernelInstance( 256, 128, 224, 128, 32, 32, 1, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 10: kernelInstance( 256, 128, 192, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 11: kernelInstance( 256, 128, 160, 128, 32, 32, 1, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 12: kernelInstance( 256, 128, 128, 256, 32, 32, 2, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 13: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 14: kernelInstance( 256, 128, 96, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), - 15: kernelInstance( 256, 128, 64, 256, 32, 32, 2, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 16: kernelInstance( 256, 64, 256, 128, 32, 32, 1, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 17: kernelInstance( 256, 64, 224, 128, 16, 16, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 18: kernelInstance( 256, 64, 192, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 19: kernelInstance( 256, 64, 192, 128, 32, 32, 1, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 20: kernelInstance( 256, 64, 160, 256, 16, 16, 2, 5, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 21: kernelInstance( 256, 64, 128, 256, 32, 32, 1, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 22: kernelInstance( 256, 64, 96, 256, 16, 16, 2, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 23: kernelInstance( 256, 64, 64, 512, 32, 32, 1, 1, [32, 8, 
1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 24: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 25: kernelInstance( 256, 32, 224, 256, 16, 16, 1, 7, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 26: kernelInstance( 256, 32, 192, 256, 16, 16, 1, 6, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 27: kernelInstance( 256, 32, 160, 256, 16, 16, 1, 5, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 28: kernelInstance( 256, 32, 128, 256, 32, 32, 1, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 29: kernelInstance( 256, 32, 96, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), - 30: kernelInstance( 256, 32, 64, 512, 16, 16, 1, 2, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 31: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [16, 16, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), - 32: kernelInstance( 256, 16, 192, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), - 33: kernelInstance( 256, 16, 128, 256, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), - 34: kernelInstance( 256, 16, 64, 512, 16, 16, 1, 1, [32, 8, 1], [32, 8, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), - 35: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 36: kernelInstance( 256, 128, 128, 64, 32, 32, 2, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), - 37: kernelInstance( 256, 256, 256, 128, 16, 16, 8, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 38: kernelInstance( 256, 256, 256, 64, 16, 16, 8, 8, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 39: kernelInstance( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - 40: kernelInstance( 256, 256, 224, 128, 16, 16, 8, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), - 41: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 5), - 42: kernelInstance( 256, 128, 256, 64, 32, 32, 2, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 43: kernelInstance( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 44: kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - 45: kernelInstance( 256, 128, 64, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 46: kernelInstance( 256, 64, 128, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - 47: kernelInstance( 256, 64, 64, 128, 32, 32, 1, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 0: KernelParameters( 256, 256, 256, 64, 32, 32, 4, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 1: KernelParameters( 256, 256, 256, 128, 32, 32, 4, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 2: KernelParameters( 256, 256, 224, 128, 32, 32, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 3: 
KernelParameters( 256, 256, 192, 128, 32, 32, 4, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 4: KernelParameters( 256, 256, 160, 128, 32, 32, 2, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 5: KernelParameters( 256, 256, 128, 128, 32, 32, 4, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 6: KernelParameters( 256, 256, 96, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 7: KernelParameters( 256, 256, 64, 128, 32, 32, 4, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 8: KernelParameters( 256, 128, 256, 128, 32, 32, 2, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 9: KernelParameters( 256, 128, 224, 128, 32, 32, 1, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 10: KernelParameters( 256, 128, 192, 128, 32, 32, 2, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 11: KernelParameters( 256, 128, 160, 128, 32, 32, 1, 5, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 12: KernelParameters( 256, 128, 128, 256, 32, 32, 2, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 13: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 14: KernelParameters( 256, 128, 96, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 1, 1, "Intrawave", 3), + 15: KernelParameters( 256, 128, 64, 256, 32, 32, 2, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 16: KernelParameters( 256, 64, 256, 128, 32, 32, 1, 4, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 17: KernelParameters( 256, 64, 224, 128, 16, 16, 2, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 18: KernelParameters( 256, 64, 192, 256, 32, 32, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 19: KernelParameters( 256, 64, 192, 128, 32, 32, 1, 3, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 20: KernelParameters( 256, 64, 160, 256, 16, 16, 2, 5, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 21: KernelParameters( 256, 64, 128, 256, 32, 32, 1, 2, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 22: KernelParameters( 256, 64, 96, 256, 16, 16, 2, 3, [16, 16, 1], [16, 16, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 23: KernelParameters( 256, 64, 64, 512, 32, 32, 1, 1, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 24: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 25: KernelParameters( 256, 32, 224, 256, 16, 16, 1, 7, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 26: KernelParameters( 256, 32, 192, 256, 16, 16, 1, 6, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 27: KernelParameters( 256, 32, 160, 256, 16, 16, 1, 5, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 28: KernelParameters( 256, 32, 128, 256, 32, 32, 1, 1, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 29: KernelParameters( 256, 32, 96, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 3), + 30: KernelParameters( 
256, 32, 64, 512, 16, 16, 1, 2, [32, 8, 1], [32, 8, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 31: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [16, 16, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), + 32: KernelParameters( 256, 16, 192, 256, 16, 16, 1, 3, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), + 33: KernelParameters( 256, 16, 128, 256, 16, 16, 1, 2, [16, 16, 1], [16, 16, 1], [1, 16, 1, 16], [8, 8, 1], 1, 2, "Intrawave", 3), + 34: KernelParameters( 256, 16, 64, 512, 16, 16, 1, 1, [32, 8, 1], [32, 8, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Intrawave", 3), + 35: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 36: KernelParameters( 256, 128, 128, 64, 32, 32, 2, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 4), + 37: KernelParameters( 256, 256, 256, 128, 16, 16, 8, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 38: KernelParameters( 256, 256, 256, 64, 16, 16, 8, 8, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 39: KernelParameters( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + 40: KernelParameters( 256, 256, 224, 128, 16, 16, 8, 7, [8, 32, 1], [8, 32, 1], [1, 64, 1, 4], [8, 8, 1], 2, 1, "Intrawave", 3), + 41: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 5), + 42: KernelParameters( 256, 128, 256, 64, 32, 32, 2, 4, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 43: KernelParameters( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 44: KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + 45: KernelParameters( 256, 128, 64, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 46: KernelParameters( 256, 64, 128, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + 47: KernelParameters( 256, 64, 64, 128, 32, 32, 1, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), # mem(Intrawave): Latency friendly - 48: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 1), - 49: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), - 50: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 1), + 48: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 1), + 49: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), + 50: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 1), # mem(Intrawave): Memory friendly, Col - 51: kernelInstance( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 52: kernelInstance( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 53: kernelInstance( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 54: 
kernelInstance( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 55: kernelInstance( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 56: kernelInstance( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 57: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - 58: kernelInstance( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), - 59: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), - 60: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 61: kernelInstance( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), - 62: kernelInstance( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 2), - 63: kernelInstance( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 64: kernelInstance( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 65: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), - 66: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Intrawave", 2), + 51: KernelParameters( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 52: KernelParameters( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 53: KernelParameters( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 54: KernelParameters( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 55: KernelParameters( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 56: KernelParameters( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 57: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + 58: KernelParameters( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), + 59: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 2), + 60: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 61: KernelParameters( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + 62: KernelParameters( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 2), + 63: KernelParameters( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 64: KernelParameters( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 65: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 
4, 1], 1, 1, "Interwave", 2), + 66: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Intrawave", 2), # mem(Interwave): Latency friendly - 67: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 1), - 68: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 1), - 69: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 1), + 67: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 1), + 68: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 1), + 69: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 1), # mem(Interwave): Memory friendly, Col - 70: kernelInstance( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 71: kernelInstance( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 72: kernelInstance( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 73: kernelInstance( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 74: kernelInstance( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 75: kernelInstance( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 76: kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - 77: kernelInstance( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - 78: kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - 79: kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 80: kernelInstance( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 81: kernelInstance( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 82: kernelInstance( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), - 83: kernelInstance( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), - 84: kernelInstance( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), - 85: kernelInstance( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Interwave", 2), + 70: KernelParameters( 256, 256, 32, 128, 32, 32, 2, 1, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 71: KernelParameters( 256, 256, 16, 128, 16, 16, 4, 1, [8, 32, 1], [8, 16, 1], [1, 32, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 72: KernelParameters( 128, 128, 32, 128, 32, 32, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 73: KernelParameters( 128, 128, 16, 128, 16, 16, 4, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), 
+ 74: KernelParameters( 128, 64, 32, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 75: KernelParameters( 128, 64, 16, 128, 16, 16, 2, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 76: KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + 77: KernelParameters( 64, 16, 16, 64, 16, 16, 1, 1, [4, 16, 1], [4, 16, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + 78: KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + 79: KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 80: KernelParameters( 128, 16, 64, 128, 16, 16, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 81: KernelParameters( 128, 32, 64, 128, 32, 32, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 82: KernelParameters( 128, 16, 128, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Interwave", 2), + 83: KernelParameters( 128, 32, 128, 128, 32, 32, 1, 2, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [8, 8, 1], 1, 1, "Interwave", 2), + 84: KernelParameters( 256, 16, 256, 128, 16, 16, 1, 4, [8, 16, 1], [8, 16, 1], [1, 16, 1, 16], [4, 4, 1], 1, 1, "Interwave", 2), + 85: KernelParameters( 256, 32, 256, 128, 32, 32, 1, 2, [8, 32, 1], [8, 32, 1], [1, 16, 1, 16], [8, 8, 1], 1, 1, "Interwave", 2), } default_kernels_dict = { # ( M, N, K): kernel: BLOCK_SIZE| MPerBLOCK| NPerBLOCK| KPerBLOCK| WAVE_TILE_M| WAVE_TILE_N| WAVE_MAP_M| WAVE_MAP_N| ABLOCK_TRANSFER| BBLOCK_TRANSFER| CBLOCK_TRANSFER| CBLOCK_SPV| CSHUFFLE_MX| CSHUFFLE_NX| LOOP_SCHED|PIPELINE_VERSION - (-1): kernelInstance( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), - (-3): kernelInstance( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), - (-4): kernelInstance( 64, 16, 16, 256, 16, 16, 1, 1, [16, 4, 1], [16, 4, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), - (-5): kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), - (-6): kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - (-7): kernelInstance( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), - (-8): kernelInstance( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), - (-9): kernelInstance( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), - (-10): kernelInstance( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), + (-1): KernelParameters( 64, 16, 16, 128, 16, 16, 1, 1, [8, 8, 1], [8, 8, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Interwave", 2), + (-3): KernelParameters( 128, 32, 16, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Interwave", 2), + (-4): KernelParameters( 64, 16, 16, 256, 16, 16, 1, 1, [16, 4, 1], [16, 4, 1], [1, 16, 1, 4], [4, 4, 1], 1, 1, "Intrawave", 1), + (-5): KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [2, 2, 1], 1, 1, "Intrawave", 2), + (-6): KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, 
[8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + (-7): KernelParameters( 256, 128, 128, 128, 32, 32, 2, 2, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Intrawave", 3), + (-8): KernelParameters( 256, 256, 128, 64, 32, 32, 4, 2, [4, 64, 1], [4, 64, 1], [1, 32, 1, 8], [8, 8, 1], 1, 1, "Interwave", 1), + (-9): KernelParameters( 256, 224, 256, 128, 16, 16, 7, 8, [8, 32, 1], [8, 32, 1], [1, 32, 1, 8], [8, 8, 1], 1, 2, "Intrawave", 3), + (-10): KernelParameters( 128, 16, 32, 128, 16, 16, 1, 1, [8, 16, 1], [8, 16, 1], [1, 16, 1, 8], [4, 4, 1], 1, 1, "Intrawave", 2), } \ No newline at end of file diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu index 4e12136e..ca92ce76 100644 --- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu +++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu @@ -26,9 +26,16 @@ static constexpr int nextPow2(unsigned int num) return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } -template -RowwiseKernel rowwise_dispatch(int id) -{ +template < + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename DDataType, + typename EDataType +> +RowwiseKernel rowwise_dispatch(int id){ // For a given shape, either find the best kernel via lookup or heuristic. // For many small M shapes, we bucket them to the next largest kernel. // This is fine since kernels are padded anyway. @@ -37,15 +44,35 @@ RowwiseKernel rowwise_dispatch(int id) static const auto lookup = [] { if constexpr (std::is_same_v) { - return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,F16)}; + return RowwiseKernelMap{ + GENERATE_LOOKUP_TABLE( + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + F16 + ) + }; } else if constexpr (std::is_same_v) { - return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(DDataType,B16)}; + return RowwiseKernelMap{ + GENERATE_LOOKUP_TABLE( + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + B16 + ) + }; } else { - static_assert(false, "rowwise_dispatch used with unsupported dtype!"); - } }(); + static_assert(false, "rowwise_dispatch used with unsupported dtype!"); + } + }(); - TORCH_CHECK(id < lookup.size(), - "Kernel id " + std::to_string(id) +" is out of range!"); + TORCH_CHECK(id < lookup.size(), "Kernel id " + std::to_string(id) +" is out of range!"); auto it = lookup.find(id); // If we found an optimal kernel, use it. 
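Stepping back to the `kernels_params_dict` table above: each `KernelParameters` row is what the generators expand into one CK instance, and its `name` is the identifier the C++ dispatchers reference. A rough sketch of that correspondence, using index 13 purely as an example and assuming it is run from `csrc/ck_gemm_a8w8/` so the module import resolves:

```python
# Hedged illustration of how a parameter row maps to a generated instance name.
# The authoritative logic is KernelParameters.name in gemm_a8w8_common.py.
from gemm_a8w8_common import kernels_params_dict

k = kernels_params_dict[13]  # 256, 128, 128, 128, 32, 32, 2, 2, [8,32,1], [8,32,1], ...
print(k.name)
# Expected to print the identifier used by the heuristic dispatcher above:
#   a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3
```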
if (it != lookup.end()) @@ -57,6 +84,34 @@ RowwiseKernel rowwise_dispatch(int id) } +torch::Tensor gemm_a8w8_int8( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y, + int kernelId, + int KBatch) +{ + std::optional bias = std::nullopt; + rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + return Y; +} + +torch::Tensor gemm_a8w8_fp8( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y, + int kernelId, + int KBatch) +{ + std::optional bias = std::nullopt; + rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + return Y; +} + torch::Tensor gemm_a8w8_tune( torch::Tensor &XQ, torch::Tensor &WQ, @@ -64,12 +119,13 @@ torch::Tensor gemm_a8w8_tune( torch::Tensor &w_scale, torch::Tensor &Y, int kernelId, - int splitK) + int splitK + ) { - TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(), - "Weights and activations should both be int8!"); - TORCH_CHECK( x_scale.dtype() == w_scale.dtype(), - "Scales should have the same dtype!"); + const auto is_int8 = XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(); + const auto is_fp8 = XQ.dtype() == at::ScalarType::Float8_e4m3fnuz && XQ.dtype() == WQ.dtype(); + TORCH_CHECK(is_int8 || is_fp8, "Weights and activations should both be either int8 or fp8!"); + TORCH_CHECK( x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!"); std::optional bias = std::nullopt; int M = XQ.size(0); @@ -77,26 +133,11 @@ torch::Tensor gemm_a8w8_tune( int K = XQ.size(1); int KBatch = std::pow(2, splitK); - // if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half) - // { - // rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias); - // } - // else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16) - // { - // rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias); - // } - // else if (Y.dtype() == at::ScalarType::Half) - // { - // rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias); - // } - // else - if (Y.dtype() == at::ScalarType::BFloat16) - { - rowwise_dispatch(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - } - else - { + if(Y.dtype() != at::ScalarType::BFloat16) TORCH_CHECK(false, "Unsupported scales/output dtype!"); - } - return Y; -} + + if(is_fp8) + return gemm_a8w8_fp8(XQ, WQ, x_scale, w_scale, Y, kernelId, KBatch); + + return gemm_a8w8_int8(XQ, WQ, x_scale, w_scale, Y, kernelId, KBatch); +} \ No newline at end of file diff --git a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py index 0d7993c4..ccbe9b6c 100644 --- a/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py +++ b/csrc/ck_gemm_a8w8/gemm_a8w8_tune.py @@ -8,7 +8,7 @@ import torch.nn.functional as F import aiter from aiter.test_common import checkAllclose, perftest -from gemm_a8w8_common import kernelInstance, kernels_list +from gemm_a8w8_common import kernels_params_dict import argparse def checkClose(a, b, rtol=1e-3, atol=0.01): @@ -48,11 +48,17 @@ def kernel_instance_test(x, weight, x_scale, w_scale, out, kernel_id, splitK=0): aiter.gemm_a8w8_tune(x, weight, x_scale, w_scale, out, kernel_id, splitK) return out +def random_tensor(a: torch.tensor, b: torch.tensor, dtype: torch.dtype) -> torch.tensor: + if dtype == torch.int8: + return torch.randint(-20, 20, (a, b), dtype=dtype, device="cuda") + elif dtype == torch.float8_e4m3fnuz: + return torch.rand((a, b), device="cuda").to(dtype) + raise 
RuntimeError("Unsupported data type.") -def tune_gemm(m, n, k, useSplitK = False): +def tune_gemm(m, n, k, dtype: torch.dtype, useSplitK = False): dim = (m, n, k) - x = torch.randint(-20, 20, (m, k), dtype=torch.int8, device="cuda") - weight = torch.randint(-20, 20, (n, k), dtype=torch.int8, device="cuda") + x = random_tensor(m, k, dtype) + weight = random_tensor(n, k, dtype) x_scale = torch.rand([m, 1], dtype=torch.bfloat16, device="cuda") w_scale = torch.rand([1, n], dtype=torch.bfloat16, device="cuda") out = torch.empty(m, n, dtype=torch.bfloat16, device="cuda") @@ -61,11 +67,10 @@ def tune_gemm(m, n, k, useSplitK = False): print(f"*******************M:{m} X N:{n} X K:{k}**************************") print(f"Start tuning a8w8 gemm kernel for M:{m}, N:{n}, K{k}:") - kernels_num = len(kernels_list) best_kernelConfig = (-1, 0) best_time = -1 - for i in range(kernels_num): - kernel = kernels_list[i] + for i in range(len(kernels_params_dict)): + kernel = kernels_params_dict[i] maxsplitK = aiter.compute_gemm_SplitK(m, n, k, kernel.MPerBLOCK, kernel.NPerBLOCK, kernel.KPerBLOCK) \ if useSplitK else 0 for splitK in range(maxsplitK+1): @@ -95,15 +100,16 @@ def tune_gemm(m, n, k, useSplitK = False): return best_kernelId, splitK, best_time -def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): +def tune_gemm_list(untunedf, tunedf, dtype: torch.dtype, issorted = False, useSplitK = False): + print("untuned df is \n\n", untunedf) for i in range(len(untunedf)): M = untunedf.loc[i, "M"] N = untunedf.loc[i, "N"] K = untunedf.loc[i, "K"] if tunedf[(tunedf["M"]==M) & (tunedf["N"]==N) & (tunedf["K"]==K)].empty: - kernelId, splitK, time = tune_gemm(M, N, K, useSplitK) - kernelName = 'None' if kernelId == -1 else kernels_list[kernelId].name + kernelId, splitK, time = tune_gemm(M, N, K, dtype, useSplitK) + kernelName = 'None' if kernelId == -1 else kernels_params_dict[kernelId].name temp = pd.DataFrame({"M":[M], "N":[N], "K":[K], "kernelId":[kernelId], "splitK":[splitK], "us":[time], "kernelName":[kernelName]}) tunedf = pd.concat([tunedf, temp], ignore_index=True) @@ -150,6 +156,14 @@ def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): help="Use splitK kernels" ) + parser.add_argument( + "-d", + "--dtype", + required=False, + default="int8", + help="int8 or fp8" + ) + parser.add_argument( "--sort", action='store_true', @@ -158,7 +172,8 @@ def tune_gemm_list(untunedf, tunedf, issorted = False, useSplitK = False): ) args = parser.parse_args() + dtype = torch.float8_e4m3fnuz if args.dtype == "fp8" else torch.int8 untunedf = get_untuned_gemm_list(args.untune_file) tunedf = get_tuned_gemm_list(args.tune_file) - tunedf = tune_gemm_list(untunedf, tunedf, args.sort, args.splitK) + tunedf = tune_gemm_list(untunedf, tunedf, dtype, args.sort, args.splitK) tunedf.to_csv(args.tune_file, index=False) diff --git a/csrc/ck_gemm_a8w8/gen_instances.py b/csrc/ck_gemm_a8w8/gen_instances.py index 347e8b23..9431f3a9 100644 --- a/csrc/ck_gemm_a8w8/gen_instances.py +++ b/csrc/ck_gemm_a8w8/gen_instances.py @@ -1,31 +1,200 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
import os -import sys -from dataclasses import dataclass -import copy from pathlib import Path import pandas as pd import argparse -import shutil -from gemm_a8w8_common import kernelInstance, kernels_list, default_kernels_dict +from gemm_a8w8_common import KernelParameters, kernels_params_dict, default_kernels_dict + +TEMPLATE_PARAMS = "template_params" +INSTANCE_FILE_SUFFIX = "instance_file_suffix" +LOOK_UP_TABLE_HEADER_PATH = "gemm_a8w8_lookup.h" +MANIFEST_HEADER_PATH = "gemm_a8w8_manifest.h" + +PARAM_INSTANCE_SUFFIX_LIST= [ + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, B16, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_BF16_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F32, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F32_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F16, F16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F16_F16.cpp" + }, + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, F32, F16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_F32_F16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, B16, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_B16_B16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F32, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F32_B16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F16, F16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F16_F16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, F32, F16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_F32_F16.cpp" + } +] + +TUNING_PARAM_INSTANCE_SUFFIX_LIST = [ + { + TEMPLATE_PARAMS: "I8, I8, I32, I8, I32, B16, B16", + INSTANCE_FILE_SUFFIX: "I8_I8_I32_I8_I32_BF16_BF16.cpp" + }, + { + TEMPLATE_PARAMS: "FP8, FP8, F32, FP8, F32, B16, B16", + INSTANCE_FILE_SUFFIX: "FP8_FP8_F32_FP8_F32_B16_B16.cpp" + }, +] + +def gen_a8w8_device_gemm_call_skip_bias_branch(k: KernelParameters, gemm_specialization: str) -> str: + return f"""using DeviceGemmInstance = DeviceGemmHelper< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, + {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + + return gemm_a8w8_rowwise_impl< + ADataType, + BDataType, + AccDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); +""" +def gen_a8w8_device_gemm_call(k: KernelParameters, gemm_specialization: str): + return f"""if (bias != std::nullopt) + {{ + using DeviceGemmInstance = DeviceGemmHelperMMA< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}, {k.CBLOCK_SPV[0]}>, + 
{k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + // Run kernel instance. + + return gemm_a8w8_mma_impl< + ADataType, + BDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + }} + else + {{ + using DeviceGemmInstance = DeviceGemmHelper< + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ComputeDataType, + DDataType, + EDataType, + {k.BLOCK_SIZE}, + {k.MPerBLOCK}, + {k.NPerBLOCK}, + {k.KPerBLOCK}, + {k.WAVE_TILE_M}, + {k.WAVE_TILE_N}, + {k.WAVE_MAP_M}, + {k.WAVE_MAP_N}, + S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, + S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, + {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, + {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, + ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, + ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, + ck::tensor_operation::device::GemmSpecialization::{gemm_specialization}>; + + return gemm_a8w8_rowwise_impl< + ADataType, + BDataType, + AccDataType, + DDataType, + EDataType, + DeviceGemmInstance + >(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); + }} +""" -class gemm_a8w8_fwd_codegen: - def __init__(self, working_path, istune=False): - self.working_path = working_path - self.impl_path = os.path.join(working_path, "impl") - self.instances_path = os.path.join(working_path, "instances") - self.istune = istune - - def gen_instance(self, k: kernelInstance): - INSTANCE_IMPL = f"""// SPDX-License-Identifier: MIT +def gen_a8w8_implementation(k: KernelParameters, skip_bias_branch: bool) -> str: + if skip_bias_branch: + gemm_a8w8_device_gemm_instance_generator = gen_a8w8_device_gemm_call_skip_bias_branch + else: + gemm_a8w8_device_gemm_instance_generator = gen_a8w8_device_gemm_call + + padding = "MNKPadding" + no_padding = "Default" + return f""" +// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_a8w8_common.cuh" -template +template < + typename ADataType, + typename BDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename AccDataType, + typename DDataType, + typename EDataType + > torch::Tensor {k.name}( torch::Tensor &XQ, @@ -35,117 +204,32 @@ def gen_instance(self, k: kernelInstance): torch::Tensor &Y, std::optional bias, int KBatch) -{{{{ +{{ // The smallest kernel we have available. Works well for memory bound shapes. // Check if this input needs to be padded. 
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); int N = WQ.size(0); int K = WQ.size(1); + bool pad = (M % {k.MPerBLOCK} != 0) || (N % {k.NPerBLOCK} != 0) || (K % ({k.KPerBLOCK} * KBatch) != 0); - if (pad) - {{{{ - // pad - {{INSTANCE_CONTENT_pad}} - // pad - }}}} - else - {{{{ - // no pad - {{INSTANCE_CONTENT_nopad}} - // no pad - }}}} -}}}} - -""" - INSTANCE_CONTENT_bias = f"""if (bias != std::nullopt) - {{{{ - using DeviceGemmInstance = DeviceGemmHelperMMA< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}, {k.CBLOCK_SPV[0]}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. - return gemm_a8w8_mma_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - }}}} - else - {{{{ - using DeviceGemmInstance = DeviceGemmHelper< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. - return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); - }}}} -""" - INSTANCE_CONTENT_nobias = f"""using DeviceGemmInstance = DeviceGemmHelper< - DDataType, EDataType, - {k.BLOCK_SIZE}, - {k.MPerBLOCK}, - {k.NPerBLOCK}, - {k.KPerBLOCK}, - {k.WAVE_TILE_M}, - {k.WAVE_TILE_N}, - {k.WAVE_MAP_M}, - {k.WAVE_MAP_N}, - S<{(", ").join(map(lambda x:str(x),k.ABLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.BBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_TRANSFER))}>, - S<{(", ").join(map(lambda x:str(x),k.CBLOCK_SPV))}>, - {k.CSHUFFLE_MX_PER_WAVE_PERSHUFFLE}, - {k.CSHUFFLE_NX_PER_WAVE_PERSHUFFLE}, - ck::BlockGemmPipelineScheduler::{k.LOOP_SCHED}, - ck::BlockGemmPipelineVersion::v{k.PIPELINE_VERSION}, - ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>; - // Run kernel instance. 
- return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y, bias, KBatch); -""" - if self.istune: - INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT_pad=(INSTANCE_CONTENT_nobias.format(GemmSpec="MNKPadding")), - INSTANCE_CONTENT_nopad=(INSTANCE_CONTENT_nobias.format(GemmSpec="Default"))) - else: - INSTANCE_IMPL_str = INSTANCE_IMPL.format(INSTANCE_CONTENT_pad=INSTANCE_CONTENT_bias.format(GemmSpec="MNKPadding"), - INSTANCE_CONTENT_nopad=INSTANCE_CONTENT_bias.format(GemmSpec="Default")) - - Path(os.path.join(self.impl_path, f"{k.name}.cuh")).write_text( - INSTANCE_IMPL_str) - - INSTANCE_template = """// SPDX-License-Identifier: MIT + if (pad) {{ + {gemm_a8w8_device_gemm_instance_generator(k, padding)} + }} + else{{ + {gemm_a8w8_device_gemm_instance_generator(k, no_padding)} + }} +}} +""" + +def gen_a8w8_instance(k: KernelParameters, template_params: str) -> str: + return f"""// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "{k.name}.cuh" template torch::Tensor -{name}<{dtypes}>( +{k.name}<{template_params}>( torch::Tensor &XQ, torch::Tensor &WQ, torch::Tensor &x_scale, @@ -153,73 +237,49 @@ def gen_instance(self, k: kernelInstance): torch::Tensor &Y, std::optional bias, int KBatch); - """ - INSTANCE_dBF16_eBF16 = INSTANCE_template.format( - name=k.name, dtypes="B16") - INSTANCE_dFP32_eBF16 = INSTANCE_template.format( - name=k.name, dtypes="F32, B16") - INSTANCE_dFP16_eFP16 = INSTANCE_template.format( - name=k.name, dtypes="F16") - INSTANCE_dFP32_eFP16 = INSTANCE_template.format( - name=k.name, dtypes="F32, F16") - - if self.istune: - Path(os.path.join(self.instances_path, f"{k.name}_dBF16_eBF16.cpp")).write_text( - INSTANCE_dBF16_eBF16) - else: - Path(os.path.join(self.instances_path, f"{k.name}_dBF16_eBF16.cpp")).write_text( - INSTANCE_dBF16_eBF16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eBF16.cpp")).write_text( - INSTANCE_dFP32_eBF16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP16_eFP16.cpp")).write_text( - INSTANCE_dFP16_eFP16) - Path(os.path.join(self.instances_path, f"{k.name}_dFP32_eFP16.cpp")).write_text( - INSTANCE_dFP32_eFP16) - - def gen_lookup_dict(self, kernels_dict): - LOOKUP_head = """#pragma once -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
-#ifdef USE_ROCM - -#define GENERATE_LOOKUP_TABLE(DTYPE, ETYPE) \\ - { \\""" - - LOOKUP_template = """ - {{{MNK}, \\ - {kernel_name}}}, \\""" - - LOOKUP_end = """ - } - -#endif // USE_ROCM -""" - with open(os.path.join(self.working_path, "gemm_a8w8_lookup.h"), "w") as f: - f.write(LOOKUP_head) - for mnk, k in kernels_dict.items(): - # print((", ").join(map(lambda x: str(x), list(mnk))), ":", k.name) - if not self.istune and (isinstance(mnk, tuple) and mnk[0] > 0): - f.write(LOOKUP_template.format(MNK="{"+(", ").join( - map(lambda x: str(x), list(mnk))) + "}", kernel_name=k.name)) - elif self.istune and isinstance(mnk, int): - f.write(LOOKUP_template.format(MNK=mnk, kernel_name=k.name)) - f.write(LOOKUP_end) - - def gen_manifest_head(self, kernels_dict): - MAINFEST_head = """#pragma once +def gen_kernel_dict_item_as_str(mnk: tuple | int, k: KernelParameters) -> str: + if isinstance(mnk, tuple): + mnk_formatted = ', '.join(str(item) for item in mnk) + else: + mnk_formatted = f"{str(mnk)}" + return f"{{ {{{mnk_formatted}}}, {k.name}}}" + +def gen_lookup_dict(kernel_dict: dict, is_tune: bool) -> str: + # Do not include default kernels in the lookup table for non-tuning calls. + filter_mnk = lambda mnk: True if is_tune else isinstance(mnk, tuple) + kernel_dict_items = [ + gen_kernel_dict_item_as_str(mnk, k) + for mnk, k in kernel_dict.items() + if filter_mnk(mnk) + ] + + lookup_table = ",\\\n ".join(kernel_dict_items) + + return f"""#pragma once // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM - -#include - -#include +#define GENERATE_LOOKUP_TABLE(A_TYPE, B_TYPE, C_SHUFFLE_TYPE, COMPUTE_TYPE, ACC_TYPE, D_TYPE, E_TYPE) \\ +{{ \\ + {lookup_table} \\ +}} +#endif """ - MAINFEST_template = """ -template + +def gen_kernel_definition(kernel_name: str) -> str: + return f""" +template < + typename ADataType, + typename BDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename AccDataType, + typename DDataType, + typename EDataType +> torch::Tensor {kernel_name}( torch::Tensor &XQ, @@ -229,43 +289,74 @@ def gen_manifest_head(self, kernels_dict): torch::Tensor &Y, std::optional bias, int KBatch); -""" - MAINFEST_end = """ + """ -#endif // USE_ROCM -""" +def gen_manifest(kernels_dict: dict) -> str: + kernel_definition_list = [ + gen_kernel_definition(k.name) for k in kernels_dict.values() + ] + kernel_definitions = "\n".join(kernel_definition_list) + + return f"""#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
- with open(os.path.join(self.working_path, "gemm_a8w8_manifest.h"), "w") as f: - f.write(MAINFEST_head) - for mnk, k in kernels_dict.items(): - f.write(MAINFEST_template.format(kernel_name=k.name)) - f.write(MAINFEST_end) +#ifdef USE_ROCM - def gen_instances(self, kernels_dict): - if os.path.exists(self.impl_path): - shutil.rmtree(self.impl_path) - os.mkdir(self.impl_path) - if os.path.exists(self.instances_path): - shutil.rmtree(self.instances_path) - os.mkdir(self.instances_path) +#include - for mnk, k in kernels_dict.items(): - self.gen_instance(k) +#include +{kernel_definitions} +#endif +""" - self.gen_lookup_dict(kernels_dict) - self.gen_manifest_head(kernels_dict) +def gemm_a8w8_fwd_codegen(working_path: Path, kernel_parameters_dict: dict, is_tune: bool): + impl_directory = Path(os.path.join(working_path, "impl")) + instance_directory = Path(os.path.join(working_path, "instances")) + impl_directory.mkdir(exist_ok=True) + instance_directory.mkdir(exist_ok=True) + # generate and write the implementation files + for _, kernel_parameters in kernel_parameters_dict.items(): + impl_path = Path(os.path.join(impl_directory, f"{kernel_parameters.name}.cuh")) + kernels_impl_str = gen_a8w8_implementation(kernel_parameters, is_tune) + impl_path.write_text(kernels_impl_str) + + + # generate and write the implementation files for each supported specialization + for _, kernel_parameters in kernel_parameters_dict.items(): + param_instance_list = TUNING_PARAM_INSTANCE_SUFFIX_LIST if is_tune else PARAM_INSTANCE_SUFFIX_LIST + for param_instance in param_instance_list: + template_params = param_instance[TEMPLATE_PARAMS] + instance_file_suffix = param_instance[INSTANCE_FILE_SUFFIX] + instance_file_name = f"{kernel_parameters.name}_{instance_file_suffix.lower()}" + instance_path = Path(os.path.join(instance_directory, instance_file_name)) + kernel_instance_str = gen_a8w8_instance(kernel_parameters, template_params) + instance_path.write_text(kernel_instance_str) + + # generate and write the lookup table + look_up_dict_str = gen_lookup_dict(kernel_parameters_dict, is_tune) + look_up_table_header_path = Path(os.path.join(working_path, LOOK_UP_TABLE_HEADER_PATH)) + look_up_table_header_path.write_text(look_up_dict_str) + + # generate and write the manifest + manifest_str = gen_manifest(kernel_parameters_dict) + manifest_header_path = Path(os.path.join(working_path, MANIFEST_HEADER_PATH)) + manifest_header_path.write_text(manifest_str) + +def get_tune_dict(tune_dict_csv: Path) -> dict: + if not os.path.exists(tune_dict_csv): + return default_kernels_dict + + tune_dict = default_kernels_dict + tune_df = pd.read_csv(tune_dict_csv) + for i in range(len(tune_df)): + M = tune_df.loc[i, "M"] + N = tune_df.loc[i, "N"] + K = tune_df.loc[i, "K"] + kid = tune_df.loc[i, "kernelId"] + tune_dict[(M, N, K)] = kernels_params_dict[kid] -def get_tune_dict(tune_dict_csv): - tune_dict = default_kernels_dict - if os.path.exists(tune_dict_csv): - tune_df = pd.read_csv(tune_dict_csv) - for i in range(len(tune_df)): - M = tune_df.loc[i, "M"] - N = tune_df.loc[i, "N"] - K = tune_df.loc[i, "K"] - kid = tune_df.loc[i, "kernelId"] - tune_dict[(M, N, K)] = kernels_list[kid] return tune_dict if __name__ == "__main__": @@ -298,29 +389,10 @@ def get_tune_dict(tune_dict_csv): help="generated tune instanses" ) - # parser.add_argument( - # "--out_type", - # default="all", - # required=False, - # help="Specifie the type of scale\n \ - # all: [bf16, fp16] \n \ - # bf16, fp16" - # ) - - # parser.add_argument( - # "--scale_type", - # 
default="all", - # required=False, - # help="Specifie the type of scale\n \ - # all: [fp32, same as out] \n \ - # same: [same as out]" - # ) - - args = parser.parse_args() - codegen = gemm_a8w8_fwd_codegen(args.working_path, args.tune) + if args.tune: - codegen.gen_instances(kernels_list) + gemm_a8w8_fwd_codegen(args.working_path, kernels_params_dict, args.tune) else: - codegen.gen_instances(get_tune_dict(args.tune_file)) + gemm_a8w8_fwd_codegen(args.working_path, get_tune_dict(args.tune_file), args.tune) diff --git a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh index a547dafb..5ee83ca1 100644 --- a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh +++ b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh @@ -46,12 +46,6 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using ADataType = I8; -using BDataType = I8; -using AccDataType = I32; -using CShuffleDataType = I32; -using ComputeDataType = I8; - using ALayout = Row; using BLayout = Col; using D0Layout = Row; @@ -66,6 +60,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = PassThrough; using BElementOp = PassThrough; +template struct RowwiseScale { template @@ -149,6 +144,49 @@ struct MultiplyMultiplyAdd e = ck::type_convert(x0_f); } + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const float& c, + const float& d0, + const float& d1, + const F16 &d2) const + { + const float x0_f = c * d0 * d1 + ck::type_convert(d2); + + e = ck::type_convert(x0_f); + } + template <> + __host__ __device__ constexpr void operator()(F16 &e, + const float& c, + const F16& d0, + const F16& d1, + const F16 &d2) const + { + const float x0_f = c * ck::type_convert(d0) * ck::type_convert(d1) + ck::type_convert(d2); + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const float& c, + const B16& d0, + const B16& d1, + const B16 &d2) const + { + const float x0_f = c * ck::type_convert(d0) * ck::type_convert(d1) + ck::type_convert(d2); + e = ck::type_convert(x0_f); + } + template <> + __host__ __device__ constexpr void operator()(B16 &e, + const float& c, + const float& d0, + const float& d1, + const B16 &d2) const + { + const float x0_f = c * d0 * d1 + ck::type_convert(d2); + e = ck::type_convert(x0_f); + } + template <> __host__ __device__ constexpr void operator()( ck::bhalf_t &e, const int &c, const ck::bhalf_t &d0, const ck::bhalf_t &d1, const ck::bhalf_t &d2) const @@ -160,7 +198,6 @@ struct MultiplyMultiplyAdd } }; -using CDEElementOp = RowwiseScale; using CDEElementOp2 = MultiplyMultiplyAdd; template @@ -169,23 +206,15 @@ using DsDataType = ck::Tuple; template using DsDataType2 = ck::Tuple; -#if 0 -template -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -// clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR -/// < Row, Row, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; -///###### RCR - < Row, Col, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; -// clang-format on -#endif template < - typename DDataType, typename EDataType, + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename DDataType, + typename EDataType, int BLOCK_SIZE, int MBLOCK, int NBLOCK, @@ -218,7 +247,7 @@ using DeviceGemmHelper = CShuffleDataType, AElementOp, BElementOp, - CDEElementOp, + RowwiseScale, GEMM_SPEC, BLOCK_SIZE, // Block Size MBLOCK, // M per Block @@ -253,7 +282,13 @@ using DeviceGemmHelper = ComputeDataType>; template < - typename DDataType, typename EDataType, + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename ComputeDataType, + typename DDataType, + typename EDataType, int BLOCK_SIZE, int MBLOCK, int NBLOCK, @@ -320,8 +355,14 @@ using DeviceGemmHelperMMA = PIPELINE_VERSION, ComputeDataType>; -template +template < + typename ADataType, + typename BDataType, + typename AccDataType, + typename DDataType, + typename EDataType, + typename DeviceGemmInstance +> __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( torch::Tensor &XQ, torch::Tensor &WQ, @@ -345,7 +386,7 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; + auto cde_element_op = RowwiseScale{}; constexpr ck::index_t NumDTensor = DeviceGemmInstance::NumDTensor; @@ -373,8 +414,13 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( return Y; } -template +template < + typename ADataType, + typename BDataType, + typename DDataType, + typename EDataType, + typename DeviceGemmInstance +> __forceinline__ torch::Tensor gemm_a8w8_mma_impl( torch::Tensor &XQ, torch::Tensor &WQ, diff --git a/op_tests/__init__.py b/op_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/op_tests/conftest.py b/op_tests/conftest.py new file mode 100644 index 00000000..85765f23 --- /dev/null +++ b/op_tests/conftest.py @@ -0,0 +1,292 @@ +import pytest +from collections import defaultdict 
+import os
+import numpy as np
+from prettytable import PrettyTable
+import torch
+import torch.profiler as tpf
+from typing import Callable, Generator, TypeVar
+
+T = TypeVar("T")
+
+AITER_LOG_MORE = int(os.environ.get('AITER_LOG_MORE', 0))
+NUM_ITERATIONS = int(os.environ.get('NUM_ITERATIONS', 100))
+NUM_WARMUP_ITERATION = int(os.environ.get('NUM_WARMUP_ITERATION', 10))
+USE_CUDA_GRAPH = int(os.environ.get('USE_CUDA_GRAPH', 0))
+
+def execute_callback(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> T:
+    """
+    Executes a callback for a given number of iterations.
+
+    Parameters
+    ----------
+    num_iterations : int
+        The number of times to call the callback.
+    func : Callable[..., T]
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    T: The last value returned by the callback function.
+
+    """
+    for _ in range(num_iterations):
+        result = func(*args, **kwargs)
+    return result
+
+def time_callback_with_cuda_event(
+    num_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> list[float]:
+    """
+    Measure the per-iteration execution time of a given function in
+    milliseconds using CUDA events.
+
+    Parameters
+    ----------
+    num_iterations : int
+        The number of iterations to use for profiling the callback.
+    func : Callable[..., T]
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    list[float]: The execution times in milliseconds.
+    """
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    latency_list = []
+    for _ in range(num_iterations):
+        start_event.record()
+        func(*args, **kwargs)
+        end_event.record()
+        end_event.synchronize()
+        latency = start_event.elapsed_time(end_event)
+        latency_list.append(latency)
+    return latency_list
+
+def extract_cuda_time_trace(torch_profiler: tpf.profile) -> list[float]:
+    """
+    Extract the CUDA times from a PyTorch profiler trace.
+
+    This function extracts the self device time of all CUDA events
+    in the profiler trace, excluding the first event.
+
+    Parameters
+    ----------
+    torch_profiler : torch.profiler.profile
+        A PyTorch profiler object containing trace events.
+
+    Returns
+    -------
+    list[float]: The execution times in milliseconds.
+
+    """
+    get_cuda_total_time = lambda event: getattr(event, "self_device_time_total", 0.0)
+    is_cuda = lambda event: getattr(event, "device_type", None) == torch.profiler.DeviceType.CUDA
+
+    if len(torch_profiler.events()) <= 1:
+        return []
+
+    # self_device_time_total is reported in microseconds; convert to milliseconds.
+    return [
+        get_cuda_total_time(event) / 1000
+        for event in torch_profiler.events()[1:]
+        if is_cuda(event)
+    ]
+
+def profile(
+    num_iterations: int,
+    num_warmup_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> tuple[list[float], T]:
+    """
+    Profile the execution of a function using the PyTorch profiler.
+
+    This function performs warmup iterations, then profiles the actual execution
+    of the function for a specified number of iterations.
+
+    Parameters
+    ----------
+    num_iterations : int
+        The number of iterations to use for profiling the callback.
+    num_warmup_iterations : int
+        The number of iterations to use for warmup before profiling the callback.
+    func : Callable[..., T]
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    tuple[list[float], T]: A tuple containing:
+        - list[float]: The execution times in milliseconds.
+        - T: The result of the last function execution.
+    """
+
+    # warmup
+    result = execute_callback(num_warmup_iterations, func, *args, **kwargs)
+
+    # Profile using CUDA events.
+    # Note: the use of the AITER_LOG_MORE variable is temporary (until we
+    # completely shift to using pytest) and should be replaced with a more
+    # descriptive flag.
+    if AITER_LOG_MORE:
+        latency_list = time_callback_with_cuda_event(
+            num_iterations, func, *args, **kwargs)
+        return latency_list, result
+
+    # Profile using torch.profiler
+    with tpf.profile(
+        activities=[
+            tpf.ProfilerActivity.CPU,
+            tpf.ProfilerActivity.CUDA
+        ],
+        profile_memory=True,
+        with_stack=True,
+        with_modules=True,
+    ) as prof:
+        execute_callback(num_iterations, func, *args, **kwargs)
+
+    latency_list = extract_cuda_time_trace(prof)
+    return latency_list, result
+
+def profile_cuda_graph(
+    num_iterations: int,
+    num_warmup_iterations: int,
+    func: Callable[..., T],
+    *args,
+    **kwargs
+) -> tuple[list[float], T]:
+    """
+    Profile the execution of a function using a CUDA graph.
+
+    This function captures the given function into a CUDA graph, then profiles
+    replays of the captured graph using the standard profile function.
+
+    Parameters
+    ----------
+    num_iterations : int
+        The number of iterations to use for profiling the callback.
+    num_warmup_iterations : int
+        The number of iterations to use for warmup before profiling the callback.
+    func : Callable[..., T]
+        A callback function with arbitrary arguments to be executed.
+    *args
+        Variable length argument list for the callback function.
+    **kwargs
+        Keyword arguments for the callback function.
+
+    Returns
+    -------
+    tuple[list[float], T]: A tuple containing:
+        - list[float]: The execution times in milliseconds.
+        - T: The result captured during graph capture.
+    """
+    # Run the function eagerly first so lazy initialization does not happen
+    # during graph capture.
+    execute_callback(num_warmup_iterations, func, *args, **kwargs)
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph):
+        result = execute_callback(1, func, *args, **kwargs)
+    # Time replays of the captured graph rather than eager calls to func.
+    latency_list, _ = profile(num_iterations, num_warmup_iterations, graph.replay)
+    return latency_list, result
+
+
+@pytest.fixture(scope="session")
+def benchmark() -> Generator[Callable[..., T], None, None]:
+    """
+    A pytest fixture that provides a benchmarking function for performance testing.
+
+    This fixture yields a function which can be used to benchmark other functions,
+    optionally using CUDA graphs. It collects statistics on execution times
+    and presents the results in a formatted table at the end of the test session.
+
+    Yields
+    ------
+    Callable[..., T]: A benchmarking function that can be called within tests.
+
+    Usage
+    -----
+    def test_example(benchmark):
+        result = benchmark(
+            func=my_function_to_test,
+            arg1=value1,
+            arg2=value2
+        )
+    """
+    test_result_data: dict[str, list[dict[str, float]]] = defaultdict(list)
+
+    def _benchmark(
+        func: Callable[..., T],
+        *args,
+        **kwargs
+    ) -> T:
+
+        test_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
+        profile_func = profile_cuda_graph if USE_CUDA_GRAPH else profile
+
+        execution_time_list, result = profile_func(
+            NUM_ITERATIONS,
+            NUM_WARMUP_ITERATION,
+            func,
+            *args,
+            **kwargs
+        )
+
+        stats = {
+            "test_name": test_name,
+            "avg_time": np.mean(execution_time_list),
+            "std_time": np.std(execution_time_list),
+            "min_time": np.min(execution_time_list),
+            "max_time": np.max(execution_time_list),
+            "iterations": NUM_ITERATIONS,
+            "warmup": NUM_WARMUP_ITERATION,
+            "cuda_graph": USE_CUDA_GRAPH
+        }
+        test_result_data[test_name].append(stats)
+
+        return result
+
+    yield _benchmark
+
+    table = PrettyTable()
+    table.field_names = ["Test Name", "Avg Time (ms)", "Std Dev (ms)", "Min Time (ms)", "Max Time (ms)", "Iterations", "Warmup", "CUDA Graph"]
+
+    for test_results in test_result_data.values():
+        for result in test_results:
+            table.add_row([
+                result["test_name"],
+                f"{result['avg_time']:.6f}",
+                f"{result['std_time']:.6f}",
+                f"{result['min_time']:.6f}",
+                f"{result['max_time']:.6f}",
+                result["iterations"],
+                result["warmup"],
+                "Yes" if result["cuda_graph"] else "No"
+            ])
+
+    table.align = "r"  # Right-align all columns
+    table.align["Test Name"] = "l"  # Left-align the Test Name column
+    table.float_format = ".6"  # Set float precision to 6 decimal places
+    table.title = "Benchmark Results"
+    table.border = True
+    table.header = True
+
+    print("\n", table)
+
diff --git a/op_tests/test_gemm_a8w8.py b/op_tests/test_gemm_a8w8.py
index a3903f5a..4b786839 100644
--- a/op_tests/test_gemm_a8w8.py
+++ b/op_tests/test_gemm_a8w8.py
@@ -1,18 +1,55 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
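The rewritten test below consumes the benchmark fixture defined in op_tests/conftest.py above. For orientation, a minimal self-contained sketch of that pattern is shown here; the file name and the my_matmul helper are illustrative only and are not part of this change:

# op_tests/test_benchmark_example.py -- illustrative sketch only
import torch

def my_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # hypothetical workload used only to demonstrate the fixture
    return a @ b

def test_matmul_benchmark(benchmark):
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    # benchmark() runs NUM_WARMUP_ITERATION warmup calls, times NUM_ITERATIONS
    # calls, records stats for the session-end table, and returns the last result.
    out = benchmark(my_matmul, a, b)
    assert out.shape == (1024, 1024)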
- -from aiter.test_common import checkAllclose, perftest, tensor_dump -import torch -import torch.nn.functional as F -import numpy as np -import sys -import os import aiter +from aiter.test_common import assert_all_close, rand_tensor from aiter.ops.shuffle import shuffle_weight +import pytest +import torch +import torch.nn.functional as F + + +MNK = [ + # qkv_proj + (1, 1280, 8192), + (32, 1280, 8192), + (64, 1280, 8192), + (128, 1280, 8192), + (192, 1280, 8192), + (256, 1280, 8192), + (320, 1280, 8192), + (512, 1280, 8192), + (1024, 1280, 8192), + (2048, 1280, 8192), + (4096, 1280, 8192), + (8192, 1280, 8192), + (16384, 1280, 8192), + # attn_out + (1, 8192, 1024), + (32, 8192, 1024), + (64, 8192, 1024), + (128, 8192, 1024), + (192, 8192, 1024), + (256, 8192, 1024), + (320, 8192, 1024), + (512, 8192, 1024), + (1024, 8192, 1024), + (2048, 8192, 1024), + (4096, 8192, 1024), + (8192, 8192, 1024), + (16384, 8192, 1024), +] +# ck_gemm only accepts the following combination for | scale : output | dtypes +# | F32 : F16 | F32 : BF16 | F16 : F16 | B16 : B16 | +CK_GEMM_SCALES_OUTPUT_DTYPES = [ + (torch.float32, torch.float16), + (torch.float32, torch.bfloat16), + (torch.float16, torch.float16), + (torch.bfloat16, torch.bfloat16), +] -@perftest() -def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): + +def torch_scaled_mm(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): x = F.linear(x.to(torch.float32), weight.to(torch.float32)) scale = torch.matmul(x_scale, w_scale) out = torch.mul(x, scale) @@ -21,77 +58,73 @@ def run_torch(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): return out.to(dtype) -@perftest() -def run_gemm_ck(x, weight, x_scale, w_scale, bias=None, dtype=torch.bfloat16): - return aiter.gemm_a8w8_CK(x, weight, x_scale, w_scale, bias) +def setup_gemm_test( + mnk: tuple[int, int, int], + ab_dtype: torch.dtype, + scales_dtype: torch.dtype, + bias_dtype: torch.dtype, + use_bias: bool, +) -> tuple[torch.tensor, ...]: + + m, n, k = mnk + a = rand_tensor(size=(m, k), dtype=ab_dtype).cuda() + b = rand_tensor(size=(n, k), dtype=ab_dtype).cuda() + a_scale = rand_tensor(size=(m, 1), dtype=scales_dtype).cuda() + 1e-6 + b_scale = rand_tensor(size=(1, n), dtype=scales_dtype).cuda() + 1e-6 + bias = (rand_tensor(size=(1, n), dtype=bias_dtype).cuda()) if use_bias else None + return a, b, a_scale, b_scale, bias -@perftest() -def run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias=None, dtype=torch.bfloat16): - return aiter.gemm_a8w8_ASM(x, weightshuffle, x_scale, w_scale, bias) +@pytest.mark.parametrize("mnk", MNK) +@pytest.mark.parametrize("ab_dtype", [torch.int8, torch.float8_e4m3fnuz]) +@pytest.mark.parametrize("scales_output_dtype", CK_GEMM_SCALES_OUTPUT_DTYPES) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_ck_gemm_close_to_torch( + benchmark, + mnk: tuple[int, int, int], + ab_dtype: torch.dtype, + scales_output_dtype: tuple[torch.dtype, torch.dtype], + use_bias: bool +) -> None: + # using bias further reduces precision, + # so we use a higher tolerance when bias is enabled. 
+ if use_bias: + rtol, atol = (1e-1, 1e-1) + else: + rtol, atol = (1e-2, 1e-1) + scales_dtype, out_dtype = scales_output_dtype + bias_dtype = out_dtype + a, b, a_scale, b_scale, bias = setup_gemm_test( + mnk, + ab_dtype, + scales_dtype, + bias_dtype, + use_bias, + ) -def test_gemm(dtype, m, n, k): - dim = (m, n, k) - x = torch.randint(-20, 20, (m, k), dtype=torch.int8).cuda() - weight = torch.randint(-20, 20, (n, k), dtype=torch.int8).cuda() - x_scale = torch.rand([m, 1], dtype=torch.float32).cuda() + 1e-6 - w_scale = torch.rand([1, n], dtype=torch.float32).cuda() + 1e-6 - bias = torch.rand([1, n], dtype=dtype).cuda() * 10 - weightshuffle = shuffle_weight(weight,layout=(32,16)) - bias_f32 = bias.to(torch.float) - x_pad, _ = F.pad(x,(0,128), "constant", 0).split([x.shape[1], 128],dim=1) - # print(f"{x_pad.shape=}{x_pad.stride()}") - # tensor_dump(x, 'x') - # tensor_dump(weight, 'weight') - # tensor_dump(shuffle_weight(weight), 'weight_shuffled') - # tensor_dump(x_scale, 'x_scale') - # tensor_dump(w_scale, 'w_scale') - # tensor_dump(bias, 'bias') + output = benchmark(aiter.gemm_a8w8_CK, a, b, a_scale, b_scale, bias, out_dtype) + expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) - a, avg_a = run_torch(x, weight, x_scale, w_scale, bias, dtype) - b, avg_b = run_gemm_ck(x, weight, x_scale, w_scale, bias, dtype) - c, avg_c = run_gemm_asm(x, weightshuffle, x_scale, w_scale, bias_f32, dtype) - if c is None: - msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, asm : not support, uplift: {avg_a/avg_b-1:<5.1%}" - else: - msg = f"[perf] dim: {str(dim):<20} dtype: {dtype}, torch avg: {avg_a:<8.2f} us, ck avg: {avg_b:<8.2f} us, asm avg: {avg_c:<8.2f} us, uplift: {avg_a/min(avg_b,avg_c)-1:<5.1%}" - checkAllclose(a, b, msg="a,b: "+msg, rtol=1e-2, atol=0.01) - if c != None: - checkAllclose(a, c, msg="\033[1A\033[2K" + "a,c: "+ msg, rtol=1e-2, atol=0.01) + assert_all_close(output, expected, rtol=rtol, atol=atol) +@pytest.mark.parametrize("mnk", MNK) +def test_asm_gemm_close_to_torch(benchmark, mnk: tuple[int, int, int]) -> None: + rtol, atol = (1e-1, 1e-1) + ab_dtype = torch.int8 + out_dtype = torch.bfloat16 + scales_dtype = torch.float32 + bias_dtype = torch.float + # gemm_a8w8_ASM requires bias and shuffle + a, b, a_scale, b_scale, bias = setup_gemm_test( + mnk, + ab_dtype, + scales_dtype, + bias_dtype, + use_bias=True, + ) + b_shuffled= shuffle_weight(b, layout=(32, 16)) -for dtype in [torch.bfloat16]: - # qkv_proj - for (m, n, k) in [ - (1, 1280, 8192), - (32, 1280, 8192), - (64, 1280, 8192), - (128, 1280, 8192), - (192, 1280, 8192), - (256, 1280, 8192), - (320, 1280, 8192), - (512, 1280, 8192), - (1024, 1280, 8192), - (2048, 1280, 8192), - (4096, 1280, 8192), - (8192, 1280, 8192), - (16384, 1280, 8192), - ]: - test_gemm(dtype, m, n, k) - # attn_out - for (m, n, k) in [ - (1, 8192, 1024), - (32, 8192, 1024), - (64, 8192, 1024), - (128, 8192, 1024), - (192, 8192, 1024), - (256, 8192, 1024), - (320, 8192, 1024), - (512, 8192, 1024), - (1024, 8192, 1024), - (2048, 8192, 1024), - (4096, 8192, 1024), - (8192, 8192, 1024), - (16384, 8192, 1024), - ]: - test_gemm(dtype, m, n, k) + output = benchmark(aiter.gemm_a8w8_ASM, a, b_shuffled, a_scale, b_scale, bias) + expected = torch_scaled_mm(a, b, a_scale, b_scale, bias, dtype=out_dtype) + if output is not None and torch.sum(output.isnan()==True) == 0: + assert_all_close(output, expected, rtol=rtol, atol=atol) diff --git a/requirements.txt b/requirements.txt index 9b80d9dc..59ac69c9 
100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ pandas<=2.0.3 pytest +prettytable==3.14.* numpy<2.0.0 psutil einops
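The prettytable pin above supports the "Benchmark Results" table printed at session teardown by the new conftest.py. A minimal sketch of how a run might be configured through the environment variables read when conftest.py is imported; the file name and values are illustrative assumptions, not part of this change:

# run_gemm_benchmarks.py -- illustrative sketch only
import os
import pytest

# These are read once when op_tests/conftest.py is imported, so they must be
# set before pytest starts collecting tests.
os.environ.setdefault("NUM_ITERATIONS", "50")       # timed iterations per benchmarked call
os.environ.setdefault("NUM_WARMUP_ITERATION", "5")  # warmup iterations before timing
os.environ.setdefault("USE_CUDA_GRAPH", "0")        # "1" to capture and replay a CUDA graph
os.environ.setdefault("AITER_LOG_MORE", "1")        # "1" to time with CUDA events instead of torch.profiler

# -s keeps stdout attached so the results table printed at session teardown is visible.
raise SystemExit(pytest.main(["-s", "op_tests/test_gemm_a8w8.py"]))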