|
| 1 | +# Codegen pipeline |
| 2 | + |
| 3 | +## hlo-opt |
| 4 | + |
| 5 | +This pass pipeline is mainly used for clustering fusion group on mhlo dialect, each fusion group was expected to fused into a single kernel in later codegen pipeline and would be outlined as a indepedent kernel function. |
| 6 | + |
| 7 | +- `ReductionFusionPass` reduction fusion in producer direction |
| 8 | + |
| 9 | +- `ElementFusionPass` elementwise/broadcast/collapse_shape/expand_shape/etc. producer-consumer bi-directional fusion |
| 10 | + |
| 11 | +- `FusionOutliningPass` fusion group outlining |
| 12 | + |
| 13 | +## linalg-tensor-opt |
| 14 | + |
| 15 | +### reduction codegen transformations |
| 16 | + |
| 17 | +``` |
| 18 | + func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { |
| 19 | + %0 = mhlo.constant dense<0.000000e+00> : tensor<f32> |
| 20 | + %1 = mhlo.convert %arg0 : (tensor<8192x50257xf16>) -> tensor<8192x50257xf32> |
| 21 | + %2 = mhlo.reduce(%1 init: %0) across dimensions = [0] : (tensor<8192x50257xf32>, tensor<f32>) -> tensor<50257xf32> |
| 22 | + reducer(%arg1: tensor<f32>, %arg2: tensor<f32>) { |
| 23 | + %3 = mhlo.add %arg1, %arg2 : tensor<f32> |
| 24 | + mhlo.return %3 : tensor<f32> |
| 25 | + } |
| 26 | + return %2 : tensor<50257xf32> |
| 27 | + } |
| 28 | +``` |
| 29 | + |
| 30 | +This pass pipeline first convert outlined mhlo fusion group into linalg dialect and try to fuse linalg op with its producer/consumer. |
| 31 | + |
| 32 | +- `createLinalgElementwiseFusionExtPass` linalg fusion pass with our extension, see [linalg pass](linalg.md) for more details |
| 33 | + |
| 34 | +``` |
| 35 | +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { |
| 36 | + %cst = arith.constant 0.000000e+00 : f32 |
| 37 | + %0 = tensor.empty() : tensor<50257xf32> |
| 38 | + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<50257xf32>) -> tensor<50257xf32> |
| 39 | + %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<8192x50257xf16>) outs(%1 : tensor<50257xf32>) { |
| 40 | + ^bb0(%in: f16, %out: f32): |
| 41 | + %3 = arith.extf %in : f16 to f32 |
| 42 | + %4 = arith.addf %out, %3 : f32 |
| 43 | + linalg.yield %4 : f32 |
| 44 | + } -> tensor<50257xf32> |
| 45 | + return %2 : tensor<50257xf32> |
| 46 | +} |
| 47 | +``` |
| 48 | + |
| 49 | +[optional] Split grid-level reduction on `reduction` dimensions |
| 50 | + |
| 51 | +``` |
| 52 | +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { |
| 53 | + %cst = arith.constant 0.000000e+00 : f32 |
| 54 | + %0 = tensor.empty() : tensor<50257xf32> |
| 55 | + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<50257xf32>) -> tensor<50257xf32> |
| 56 | + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<8192x50257xf16> into tensor<32x256x50257xf16> |
| 57 | + %2 = tensor.empty() : tensor<32x50257xf32> |
| 58 | + %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<32x50257xf32>) -> tensor<32x50257xf32> |
| 59 | + %4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%expanded : tensor<32x256x50257xf16>) outs(%3 : tensor<32x50257xf32>) attrs = {__grid_reduction__} { |
| 60 | + ^bb0(%in: f16, %out: f32): |
| 61 | + %6 = arith.extf %in : f16 to f32 |
| 62 | + %7 = arith.addf %out, %6 : f32 |
| 63 | + linalg.yield %7 : f32 |
| 64 | + } -> tensor<32x50257xf32> |
| 65 | + %5 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["reduction", "parallel"]} ins(%4 : tensor<32x50257xf32>) outs(%1 : tensor<50257xf32>) attrs = {__grid_reduction__} { |
| 66 | + ^bb0(%in: f32, %out: f32): |
| 67 | + %6 = arith.addf %in, %out : f32 |
| 68 | + linalg.yield %6 : f32 |
| 69 | + } -> tensor<50257xf32> |
| 70 | + return %5 : tensor<50257xf32> |
| 71 | +} |
| 72 | +``` |
| 73 | + |
| 74 | +- Tiling reduction on `parallel` dimension and mapping tiled reductions to thread blocks |
| 75 | + |
| 76 | +``` |
| 77 | +func.func private @Unknown0(%arg0: tensor<8192x50257xf16>) -> tensor<50257xf32> attributes {__byteir_reduction_fusion__} { |
| 78 | + %cst = arith.constant 0.000000e+00 : f32 |
| 79 | + %0 = tensor.empty() : tensor<50257xf32> |
| 80 | + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] : tensor<8192x50257xf16> into tensor<32x256x50257xf16> |
| 81 | + %1 = tensor.empty() : tensor<32x50257xf32> |
| 82 | + %2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { |
| 83 | + %4 = affine.min #map(%arg2) |
| 84 | + %5 = affine.apply #map1(%arg2) |
| 85 | + %extracted_slice = tensor.extract_slice %expanded[%arg1, 0, %5] [1, 256, %4] [1, 1, 1] : tensor<32x256x50257xf16> to tensor<256x?xf16> |
| 86 | + %extracted_slice_0 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor<?xf32> |
| 87 | + %6 = linalg.fill ins(%cst : f32) outs(%extracted_slice_0 : tensor<?xf32>) -> tensor<?xf32> |
| 88 | + %7 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<256x?xf16>) outs(%6 : tensor<?xf32>) { |
| 89 | + ^bb0(%in: f16, %out: f32): |
| 90 | + %8 = arith.extf %in : f16 to f32 |
| 91 | + %9 = arith.addf %out, %8 : f32 |
| 92 | + linalg.yield %9 : f32 |
| 93 | + } -> tensor<?xf32> |
| 94 | + scf.forall.in_parallel { |
| 95 | + tensor.parallel_insert_slice %7 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<?xf32> into tensor<32x50257xf32> |
| 96 | + } |
| 97 | + } {mapping = [#gpu.block<y>, #gpu.block<x>]} |
| 98 | + %3 = scf.forall (%arg1) in (1571) shared_outs(%arg2 = %0) -> (tensor<50257xf32>) { |
| 99 | + // ... |
| 100 | + } {mapping = [#gpu.block<x>]} |
| 101 | + return %3 : tensor<50257xf32> |
| 102 | +} |
| 103 | +``` |
| 104 | + |
| 105 | +- Block-level reduction codegen |
| 106 | + |
| 107 | +``` |
| 108 | +%2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { |
| 109 | + %4 = affine.min #map(%arg2) |
| 110 | + %5 = affine.apply #map1(%arg2) |
| 111 | + %extracted_slice = tensor.extract_slice %expanded[%arg1, 0, %5] [1, 256, %4] [1, 1, 1] : tensor<32x256x50257xf16> to tensor<256x?xf16> |
| 112 | + %6 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<32xf32> |
| 113 | + %7 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x32xf32> |
| 114 | + %8 = scf.forall (%arg4, %arg5) in (16, 32) shared_outs(%arg6 = %7) -> (tensor<16x32xf32>) { |
| 115 | + %17 = affine.min #map2(%arg4) |
| 116 | + %18 = affine.min #map3(%arg4) |
| 117 | + %19 = affine.apply #map4(%18, %17) |
| 118 | + %20 = affine.min #map5(%arg5, %arg2) |
| 119 | + %21 = affine.min #map6(%arg5, %arg2) |
| 120 | + %22 = affine.apply #map4(%21, %20) |
| 121 | + %23 = affine.apply #map7(%21, %20) |
| 122 | + %extracted_slice_6 = tensor.extract_slice %extracted_slice[%17, %20] [%19, %22] [1, 1] : tensor<256x?xf16> to tensor<?x?xf16> |
| 123 | + %padded = tensor.pad %extracted_slice_6 low[0, 0] high[0, %23] { |
| 124 | + ^bb0(%arg7: index, %arg8: index): |
| 125 | + tensor.yield %cst : f16 |
| 126 | + } : tensor<?x?xf16> to tensor<16x1xf16> |
| 127 | + %extracted_slice_7 = tensor.extract_slice %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<16x32xf32> to tensor<f32> |
| 128 | + %collapsed = tensor.collapse_shape %padded [[0, 1]] : tensor<16x1xf16> into tensor<16xf16> |
| 129 | + %24 = linalg.fill ins(%cst_0 : f32) outs(%extracted_slice_7 : tensor<f32>) -> tensor<f32> |
| 130 | + %25 = linalg.generic {indexing_maps = [#map8, #map9], iterator_types = ["reduction"]} ins(%collapsed : tensor<16xf16>) outs(%24 : tensor<f32>) { |
| 131 | + ^bb0(%in: f16, %out: f32): |
| 132 | + %26 = arith.extf %in : f16 to f32 |
| 133 | + %27 = arith.addf %out, %26 : f32 |
| 134 | + linalg.yield %27 : f32 |
| 135 | + } -> tensor<f32> |
| 136 | + scf.forall.in_parallel { |
| 137 | + tensor.parallel_insert_slice %25 into %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<f32> into tensor<16x32xf32> |
| 138 | + } |
| 139 | + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} |
| 140 | + %expanded_1 = tensor.expand_shape %8 [[0, 1], [2]] : tensor<16x32xf32> into tensor<8x2x32xf32> |
| 141 | + %9 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<8x32xf32> |
| 142 | + %10 = scf.forall (%arg4, %arg5) in (8, 32) shared_outs(%arg6 = %9) -> (tensor<8x32xf32>) { |
| 143 | + // ... |
| 144 | + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} |
| 145 | + %expanded_2 = tensor.expand_shape %10 [[0, 1], [2]] : tensor<8x32xf32> into tensor<4x2x32xf32> |
| 146 | + %11 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<4x32xf32> |
| 147 | + %12 = scf.forall (%arg4, %arg5) in (4, 32) shared_outs(%arg6 = %11) -> (tensor<4x32xf32>) { |
| 148 | + // ... |
| 149 | + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} |
| 150 | + %expanded_3 = tensor.expand_shape %12 [[0, 1], [2]] : tensor<4x32xf32> into tensor<2x2x32xf32> |
| 151 | + %13 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<2x32xf32> |
| 152 | + %14 = scf.forall (%arg4, %arg5) in (2, 32) shared_outs(%arg6 = %13) -> (tensor<2x32xf32>) { |
| 153 | + // ... |
| 154 | + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} |
| 155 | + %15 = scf.forall (%arg4) in (32) shared_outs(%arg5 = %6) -> (tensor<32xf32>) { |
| 156 | + // ... |
| 157 | + } {mapping = [#gpu.thread<x>]} |
| 158 | + %extracted_slice_4 = tensor.extract_slice %15[0] [%4] [1] : tensor<32xf32> to tensor<?xf32> |
| 159 | + %extracted_slice_5 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor<?xf32> |
| 160 | + %16 = scf.forall (%arg4) in (512) shared_outs(%arg5 = %extracted_slice_5) -> (tensor<?xf32>) { |
| 161 | + // ... |
| 162 | + } {mapping = [#gpu.linear<x>]} |
| 163 | + scf.forall.in_parallel { |
| 164 | + tensor.parallel_insert_slice %16 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<?xf32> into tensor<32x50257xf32> |
| 165 | + } |
| 166 | +} {mapping = [#gpu.block<y>, #gpu.block<x>]} |
| 167 | +``` |
| 168 | + |
| 169 | +- Detensorize scalar linalg ops to arith ops and specialize `tensor.pad` |
| 170 | + |
| 171 | +``` |
| 172 | +%2 = scf.forall (%arg1, %arg2) in (32, 1571) shared_outs(%arg3 = %1) -> (tensor<32x50257xf32>) { |
| 173 | + %4 = affine.min #map(%arg2) |
| 174 | + %5 = affine.apply #map1(%arg2) |
| 175 | + %6 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<32xf32> |
| 176 | + %7 = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<16x32xf32> |
| 177 | + %8 = scf.forall (%arg4, %arg5) in (16, 32) shared_outs(%arg6 = %7) -> (tensor<16x32xf32>) { |
| 178 | + %17 = affine.min #map2(%arg5, %arg2) |
| 179 | + %18 = affine.min #map3(%arg5, %arg2) |
| 180 | + %19 = affine.apply #map4(%18, %17) |
| 181 | + %20 = arith.cmpi ugt, %19, %c0 : index |
| 182 | + %21 = scf.if %20 -> (f16) { |
| 183 | + %84 = affine.apply #map5(%arg4) |
| 184 | + %85 = affine.apply #map6(%arg2)[%17] |
| 185 | + %extracted = tensor.extract %expanded[%arg1, %84, %85] : tensor<32x256x50257xf16> |
| 186 | + scf.yield %extracted : f16 |
| 187 | + } else { |
| 188 | + scf.yield %cst : f16 |
| 189 | + } |
| 190 | + // ... |
| 191 | + %78 = arith.extf %77 : f16 to f32 |
| 192 | + %79 = arith.addf %75, %78 : f32 |
| 193 | + %80 = arith.cmpi ugt, %19, %c0 : index |
| 194 | + %81 = scf.if %80 -> (f16) { |
| 195 | + %84 = affine.apply #map21(%arg4) |
| 196 | + %85 = affine.apply #map6(%arg2)[%17] |
| 197 | + %extracted = tensor.extract %expanded[%arg1, %84, %85] : tensor<32x256x50257xf16> |
| 198 | + scf.yield %extracted : f16 |
| 199 | + } else { |
| 200 | + scf.yield %cst : f16 |
| 201 | + } |
| 202 | + %82 = arith.extf %81 : f16 to f32 |
| 203 | + %83 = arith.addf %79, %82 : f32 |
| 204 | + %extracted_slice_5 = tensor.extract_slice %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<16x32xf32> to tensor<f32> |
| 205 | + %inserted = tensor.insert %83 into %extracted_slice_5[] : tensor<f32> |
| 206 | + scf.forall.in_parallel { |
| 207 | + tensor.parallel_insert_slice %inserted into %arg6[%arg4, %arg5] [1, 1] [1, 1] : tensor<f32> into tensor<16x32xf32> |
| 208 | + } |
| 209 | + } {mapping = [#gpu.thread<y>, #gpu.thread<x>]} |
| 210 | + |
| 211 | + // ... |
| 212 | + %extracted_slice = tensor.extract_slice %15[0] [%4] [1] : tensor<32xf32> to tensor<?xf32> |
| 213 | + %extracted_slice_4 = tensor.extract_slice %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<32x50257xf32> to tensor<?xf32> |
| 214 | + %16 = scf.forall (%arg4) in (512) shared_outs(%arg5 = %extracted_slice_4) -> (tensor<?xf32>) { |
| 215 | + %17 = affine.min #map22(%arg4)[%4] |
| 216 | + %18 = affine.max #map23(%17) |
| 217 | + %19 = affine.apply #map24(%arg4)[%4] |
| 218 | + %extracted_slice_5 = tensor.extract_slice %extracted_slice[%19] [%18] [1] : tensor<?xf32> to tensor<?xf32> |
| 219 | + %extracted_slice_6 = tensor.extract_slice %arg5[%19] [%18] [1] : tensor<?xf32> to tensor<?xf32> |
| 220 | + %20 = linalg.copy {__byteir_gpu_tile_block_reduction_10} ins(%extracted_slice_5 : tensor<?xf32>) outs(%extracted_slice_6 : tensor<?xf32>) -> tensor<?xf32> |
| 221 | + scf.forall.in_parallel { |
| 222 | + tensor.parallel_insert_slice %20 into %arg5[%19] [%18] [1] : tensor<?xf32> into tensor<?xf32> |
| 223 | + } |
| 224 | + } {mapping = [#gpu.linear<x>]} |
| 225 | + scf.forall.in_parallel { |
| 226 | + tensor.parallel_insert_slice %16 into %arg3[%arg1, %5] [1, %4] [1, 1] : tensor<?xf32> into tensor<32x50257xf32> |
| 227 | + } |
| 228 | +} {mapping = [#gpu.block<y>, #gpu.block<x>]} |
| 229 | +``` |
| 230 | + |
| 231 | +- `structured.split_reduction` split reduction op along `reduction` dimension for increasing parallelism |
| 232 | + |
| 233 | +- `structured.tile_to_forall_op` tile reduction op along `parallel` dimensions to `forall` op and mapping to block/linear/thread |
| 234 | + |
| 235 | +- `structured.fuse_into_containing_op` fuse init and pad operands into `scf.forall` |
| 236 | + |
| 237 | +- `structured.annotate` attach any attribute to target ops, used to annotate reduction op and attach memory space to `allot_tensor` |
| 238 | + |
| 239 | +- `structured.tile` tile reduction op along `reduction` dimension to sequential for loop |
| 240 | + |
| 241 | +- `structured.detensorize` use to inline computation region of linalg op which operands have scalar tensor type |
| 242 | + |
| 243 | +- `LinalgCollapseLoopsPass` collapse consecutive `parallel` and `reduction` loops, this pass could work on both tensor and memref |
| 244 | + |
| 245 | +- `TensorPadSpecializationPass` specialize `tensor.extract` of pad op to conditional read |
0 commit comments