@@ -1,6 +1,8 @@
 #include <gtest/gtest.h>
 #include <test/cpp/tensorexpr/test_base.h>
 
+#include <c10/util/irange.h>
+#include <test/cpp/tensorexpr/padded_buffer.h>
 #include <torch/csrc/jit/tensorexpr/ir.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
@@ -85,6 +87,232 @@ TEST(BufLiveRange, MulRangeLine) {
   ASSERT_TRUE(std::get<1>(range_b) == 1);
 }
 
+TEST(MemPlanning, MemReuseWithTypeCast) {
+  int M = 4;
+  int N = 4;
+  int K = 4;
+
+  BufHandle AP("A", {M, K}, kFloat);
+  BufHandle BP("B", {K, N}, kFloat);
+
+  Tensor CT = Reduce(
+      "gemm",
+      {M, N},
+      Sum(),
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
+        return AP.load(m, k) * BP.load(k, n);
+      },
+      {K});
+  Tensor DT =
+      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return CompareSelect::make(
+            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
+      });
+  Tensor ET =
+      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
+      });
+  Tensor FT =
+      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return ET.load(m, n);
+      });
+  StmtPtr stmt =
+      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
+
+  // Constructed stmt:
+  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
+  // E [2, 3]. 'gemm' and 'E' have the same dimensions but different types:
+  // 'E' (quint8) is narrower than 'gemm' (float), so the memory planner can
+  // reuse 'gemm' for 'E' with a type cast.
+  //{
+  //  for (int i = 0; i < 4; i++) {
+  //    for (int i_1 = 0; i_1 < 4; i_1++) {
+  //      gemm[i, i_1] = float(0);
+  //      for (int i_2 = 0; i_2 < 4; i_2++) {
+  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) *
+  //            (B[i_2, i_1]), reduce_args={i_2});
+  //      }
+  //    }
+  //  }
+  //  for (int i_3 = 0; i_3 < 4; i_3++) {
+  //    for (int i_4 = 0; i_4 < 4; i_4++) {
+  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
+  //    }
+  //  }
+  //  for (int i_5 = 0; i_5 < 4; i_5++) {
+  //    for (int i_6 = 0; i_6 < 4; i_6++) {
+  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
+  //    }
+  //  }
+  //  for (int i_7 = 0; i_7 < 4; i_7++) {
+  //    for (int i_8 = 0; i_8 < 4; i_8++) {
+  //      F[i_7, i_8] = E[i_7, i_8];
+  //    }
+  //  }
+  //}
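+  //
+  // Size arithmetic for the reuse decision (assuming the planner compares
+  // total byte sizes): 'E' needs 4 * 4 * sizeof(quint8) = 16 bytes, which
+  // fits in the 4 * 4 * sizeof(float) = 64 bytes already allocated for
+  // 'gemm', so aliasing 'E' onto 'gemm' is safe.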
+
+  LoopNest l(stmt, {FT.buf()});
+  l.prepareForCodegen();
+  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
+# CHECK: Alias(E,gemm);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  PaddedBuffer<float> a_v(M, K, "a");
+  PaddedBuffer<float> b_v(K, N, "b");
+  PaddedBuffer<uint8_t> o1(M, N, "e_before");
+  PaddedBuffer<uint8_t> o2(M, N, "e_after");
+
+  for (const auto m : c10::irange(M)) {
+    for (const auto k : c10::irange(K)) {
+      a_v(m, k) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  for (const auto k : c10::irange(K)) {
+    for (const auto n : c10::irange(N)) {
+      b_v(k, n) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  cg.call({a_v, b_v, o1});
+
+#ifdef TORCH_ENABLE_LLVM
+  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg_llvm.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
+# CHECK: Alias(E,gemm);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  cg_llvm.call({a_v, b_v, o2});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(o1, o2, 1e-5);
+#endif
+}
+
+TEST(MemPlanning, NoMemReuseForLargerType) {
+  int M = 4;
+  int N = 4;
+  int K = 4;
+
+  BufHandle AP("A", {M, K}, kShort);
+  BufHandle BP("B", {K, N}, kShort);
+
+  Tensor CT = Reduce(
+      "gemm",
+      {M, N},
+      Sum(),
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
+        return AP.load(m, k) * BP.load(k, n);
+      },
+      {K});
+  auto zero = Cast::make(CT.buf()->dtype(), 0);
+  Tensor DT =
+      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return CompareSelect::make(
+            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
+      });
+  Tensor ET =
+      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
+      });
+  Tensor FT =
+      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return ET.load(m, n);
+      });
+  StmtPtr stmt =
+      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
+
+  // Constructed stmt:
+  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
+  // E [2, 3]. 'gemm' and 'E' have the same dimensions but different types:
+  // 'E' (float) is wider than 'gemm' (int16), so 'gemm' cannot be reused
+  // for 'E'.
+  //{
+  //  for (int i = 0; i < 4; i++) {
+  //    for (int i_1 = 0; i_1 < 4; i_1++) {
+  //      gemm[i, i_1] = int16_t(0);
+  //      for (int i_2 = 0; i_2 < 4; i_2++) {
+  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) *
+  //            (B[i_2, i_1]), reduce_args={i_2});
+  //      }
+  //    }
+  //  }
+  //  for (int i_3 = 0; i_3 < 4; i_3++) {
+  //    for (int i_4 = 0; i_4 < 4; i_4++) {
+  //      relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0)
+  //          : (gemm[i_3, i_4]);
+  //    }
+  //  }
+  //  for (int i_5 = 0; i_5 < 4; i_5++) {
+  //    for (int i_6 = 0; i_6 < 4; i_6++) {
+  //      E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6]));
+  //    }
+  //  }
+  //  for (int i_7 = 0; i_7 < 4; i_7++) {
+  //    for (int i_8 = 0; i_8 < 4; i_8++) {
+  //      F[i_7, i_8] = E[i_7, i_8];
+  //    }
+  //  }
+  //}
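+  //
+  // Size arithmetic for the no-reuse decision (assuming the planner compares
+  // total byte sizes): 'E' needs 4 * 4 * sizeof(float) = 64 bytes, more than
+  // the 4 * 4 * sizeof(int16_t) = 32 bytes held by 'gemm', so aliasing would
+  // overflow and 'E' gets its own allocation.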
+
+  LoopNest l(stmt, {FT.buf()});
+  l.prepareForCodegen();
+  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT.buf()});
+
+  checkIR(cg.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
+# CHECK: Free(E);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  PaddedBuffer<short> a_v(M, K, "a");
+  PaddedBuffer<short> b_v(K, N, "b");
+  PaddedBuffer<float> o1(M, N, "e_before");
+  PaddedBuffer<float> o2(M, N, "e_after");
+
+  for (const auto m : c10::irange(M)) {
+    for (const auto k : c10::irange(K)) {
+      a_v(m, k) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  for (const auto k : c10::irange(K)) {
+    for (const auto n : c10::irange(N)) {
+      b_v(k, n) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  cg.call({a_v, b_v, o1});
+
+#ifdef TORCH_ENABLE_LLVM
+  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg_llvm.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
+# CHECK: Free(E);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  cg_llvm.call({a_v, b_v, o2});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(o1, o2, 1e-5);
+#endif
+}
+
 TEST(MemPlanning, SameBufSizeMemReuse) {
   int M = 1024;
   int N = 1024;