Commit 5eb5b61

huiguoo authored and pytorchmergebot committed
[tensorexpr] Add typecast when src and dest buf types are different in PlacementAllocate (pytorch#71934)
Summary: Pull Request resolved: pytorch#71934
Test Plan: Imported from OSS
Reviewed By: ZolotukhinM
Differential Revision: D33826700
Pulled By: huiguoo
fbshipit-source-id: 9fb29a43ab5983586a6bfde3a34d7e2f2120ab0a
(cherry picked from commit 2bee018)
1 parent 555b215 commit 5eb5b61
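The memory planner's PlacementAllocate node makes a later buffer reuse the storage of an earlier, now-dead intermediate; this commit allows such reuse across dtypes by inserting a pointer typecast. A minimal standalone sketch of the underlying idea, with hypothetical names rather than the tensorexpr API: reuse is sound when the new buffer needs no more bytes than the old one, and the same storage is then reinterpreted at the new element type.

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // A 4x4 float intermediate ("gemm") whose live range has ended: 64 bytes.
  std::vector<float> gemm(4 * 4);

  // The later 4x4 quint8 buffer ("E") needs only 16 bytes, so it fits.
  static_assert(
      4 * 4 * sizeof(uint8_t) <= 4 * 4 * sizeof(float),
      "reuse is only legal when the destination fits in the source");

  // The "typecast" this commit adds: reinterpret the same storage at the
  // destination element type (the LLVM backend emits a pointer cast for this).
  uint8_t* e = reinterpret_cast<uint8_t*>(gemm.data());
  e[0] = 42;
  return 0;
}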

File tree

3 files changed (+238, -3 lines)


test/cpp/tensorexpr/test_memplanning.cpp (+228)
@@ -1,6 +1,8 @@
 #include <gtest/gtest.h>
 #include <test/cpp/tensorexpr/test_base.h>
 
+#include <c10/util/irange.h>
+#include <test/cpp/tensorexpr/padded_buffer.h>
 #include <torch/csrc/jit/tensorexpr/ir.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
@@ -85,6 +87,232 @@ TEST(BufLiveRange, MulRangeLine) {
   ASSERT_TRUE(std::get<1>(range_b) == 1);
 }
 
+TEST(MemPlanning, MemReuseWithTypeCast) {
+  int M = 4;
+  int N = 4;
+  int K = 4;
+
+  BufHandle AP("A", {M, K}, kFloat);
+  BufHandle BP("B", {K, N}, kFloat);
+
+  Tensor CT = Reduce(
+      "gemm",
+      {M, N},
+      Sum(),
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
+        return AP.load(m, k) * BP.load(k, n);
+      },
+      {K});
+  Tensor DT =
+      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return CompareSelect::make(
+            CT.load(m, n), 0.0f, 0.0f, CT.load(m, n), kLT);
+      });
+  Tensor ET =
+      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return Cast::make(kQUInt8, DT.load(m, n) + DT.load(m, n));
+      });
+  Tensor FT =
+      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return ET.load(m, n);
+      });
+  StmtPtr stmt =
+      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
+
+  // Constructed stmt:
+  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
+  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
+  // are different: 'E' type quint8 < 'gemm' type float. We'll reuse 'gemm'
+  // for 'E' with typecasting.
+  //{
+  //  for (int i = 0; i < 4; i++) {
+  //    for (int i_1 = 0; i_1 < 4; i_1++) {
+  //      gemm[i, i_1] = float(0);
+  //      for (int i_2 = 0; i_2 < 4; i_2++) {
+  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
+  //        i_1]), reduce_args={i_2});
+  //      }
+  //    }
+  //  }
+  //  for (int i_3 = 0; i_3 < 4; i_3++) {
+  //    for (int i_4 = 0; i_4 < 4; i_4++) {
+  //      relu[i_3, i_4] = (gemm[i_3, i_4])<0.f ? 0.f : (gemm[i_3, i_4]);
+  //    }
+  //  }
+  //  for (int i_5 = 0; i_5 < 4; i_5++) {
+  //    for (int i_6 = 0; i_6 < 4; i_6++) {
+  //      E[i_5, i_6] = quint8((relu[i_5, i_6]) + (relu[i_5, i_6]));
+  //    }
+  //  }
+  //  for (int i_7 = 0; i_7 < 4; i_7++) {
+  //    for (int i_8 = 0; i_8 < 4; i_8++) {
+  //      F[i_7, i_8] = E[i_7, i_8];
+  //    }
+  //  }
+  //}
+
+  LoopNest l(stmt, {FT.buf()});
+  l.prepareForCodegen();
+  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
+# CHECK: Alias(E,gemm);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  PaddedBuffer<float> a_v(M, K, "a");
+  PaddedBuffer<float> b_v(K, N, "b");
+  PaddedBuffer<uint8_t> o1(M, N, "e_before");
+  PaddedBuffer<uint8_t> o2(M, N, "e_after");
+
+  for (const auto m : c10::irange(M)) {
+    for (const auto k : c10::irange(K)) {
+      a_v(m, k) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  for (const auto k : c10::irange(K)) {
+    for (const auto n : c10::irange(N)) {
+      b_v(k, n) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  cg.call({a_v, b_v, o1});
+
+#ifdef TORCH_ENABLE_LLVM
+  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg_llvm.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=float, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=float, dims=[4, 4]
+# CHECK: Alias(E,gemm);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  cg_llvm.call({a_v, b_v, o2});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(o1, o2, 1e-5);
+#endif
+}
+
+TEST(MemPlanning, NoMemReuseForLargerType) {
+  int M = 4;
+  int N = 4;
+  int K = 4;
+
+  BufHandle AP("A", {M, K}, kShort);
+  BufHandle BP("B", {K, N}, kShort);
+
+  Tensor CT = Reduce(
+      "gemm",
+      {M, N},
+      Sum(),
+      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
+        return AP.load(m, k) * BP.load(k, n);
+      },
+      {K});
+  auto zero = Cast::make(CT.buf()->dtype(), 0);
+  Tensor DT =
+      Compute("relu", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return CompareSelect::make(
+            CT.load(m, n), zero, zero, CT.load(m, n), kLT);
+      });
+  Tensor ET =
+      Compute("E", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return Cast::make(kFloat, DT.load(m, n) + DT.load(m, n));
+      });
+  Tensor FT =
+      Compute("F", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
+        return ET.load(m, n);
+      });
+  StmtPtr stmt =
+      tensorexpr::Block::make({CT.stmt(), DT.stmt(), ET.stmt(), FT.stmt()});
+
+  // Constructed stmt:
+  // Intermediate buffers and their liveness ranges: gemm [0, 1], relu [1, 2],
+  // E [2, 3]. The dimensions of 'gemm' and 'E' are the same but their types
+  // are different: 'E' type float > 'gemm' type int16. We won't reuse 'gemm'
+  // for 'E'.
+  //{
+  //  for (int i = 0; i < 4; i++) {
+  //    for (int i_1 = 0; i_1 < 4; i_1++) {
+  //      gemm[i, i_1] = int16_t(0);
+  //      for (int i_2 = 0; i_2 < 4; i_2++) {
+  //        gemm[i, i_1] = ReduceOp((gemm[i, i_1]) + (A[i, i_2]) * (B[i_2,
+  //        i_1]), reduce_args={i_2});
+  //      }
+  //    }
+  //  }
+  //  for (int i_3 = 0; i_3 < 4; i_3++) {
+  //    for (int i_4 = 0; i_4 < 4; i_4++) {
+  //      relu[i_3, i_4] = (gemm[i_3, i_4])<int16_t(0) ? int16_t(0) :
+  //      (gemm[i_3, i_4]);
+  //    }
+  //  }
+  //  for (int i_5 = 0; i_5 < 4; i_5++) {
+  //    for (int i_6 = 0; i_6 < 4; i_6++) {
+  //      E[i_5, i_6] = float((relu[i_5, i_6]) + (relu[i_5, i_6]));
+  //    }
+  //  }
+  //  for (int i_7 = 0; i_7 < 4; i_7++) {
+  //    for (int i_8 = 0; i_8 < 4; i_8++) {
+  //      F[i_7, i_8] = E[i_7, i_8];
+  //    }
+  //  }
+  //}
+
+  LoopNest l(stmt, {FT.buf()});
+  l.prepareForCodegen();
+  SimpleIREvaluator cg(Stmt::clone(l.root_stmt()), {AP, BP, FT.buf()});
+
+  checkIR(cg.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
+# CHECK: Free(E);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  PaddedBuffer<short> a_v(M, K, "a");
+  PaddedBuffer<short> b_v(K, N, "b");
+  PaddedBuffer<float> o1(M, N, "e_before");
+  PaddedBuffer<float> o2(M, N, "e_after");
+
+  for (const auto m : c10::irange(M)) {
+    for (const auto k : c10::irange(K)) {
+      a_v(m, k) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  for (const auto k : c10::irange(K)) {
+    for (const auto n : c10::irange(N)) {
+      b_v(k, n) = at::randn({1}).item().to<float>();
+    }
+  }
+
+  cg.call({a_v, b_v, o1});
+
+#ifdef TORCH_ENABLE_LLVM
+  LLVMCodeGen cg_llvm(Stmt::clone(l.root_stmt()), {AP, BP, FT});
+
+  checkIR(cg_llvm.stmt(), R"IR(
+# CHECK: Allocate(gemm); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(relu); // dtype=int16_t, dims=[4, 4]
+# CHECK: Allocate(E); // dtype=float, dims=[4, 4]
+# CHECK: Free(E);
+# CHECK: Free(relu);
+# CHECK: Free(gemm))IR");
+
+  cg_llvm.call({a_v, b_v, o2});
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
+  ExpectAllNear(o1, o2, 1e-5);
+#endif
+}
+
 TEST(MemPlanning, SameBufSizeMemReuse) {
   int M = 1024;
   int N = 1024;

torch/csrc/jit/tensorexpr/eval.cpp (+1, -1)
@@ -983,7 +983,7 @@ class SimpleIREvaluatorImpl : public IRVisitor {
   }
 
   void visit(PlacementAllocatePtr v) override {
-    buffer_mapping_[v->buf()] = buffer_mapping_[v->buf_to_reuse()];
+    buffer_mapping_[v->buf()] = buffer_mapping_.at(v->buf_to_reuse());
   }
 
   void visit(FreePtr v) override {
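A side note on this one-line change: std::unordered_map::operator[] default-constructs and inserts a value for a missing key, which would silently map the aliased buffer to an empty entry if the buffer to reuse had never been materialized, while .at() turns that planner bug into a loud std::out_of_range. A small sketch of the difference, using toy key/value types in place of the evaluator's buffer_mapping_:

#include <cassert>
#include <stdexcept>
#include <unordered_map>

int main() {
  std::unordered_map<int, void*> buffer_mapping;

  // operator[] inserts {1, nullptr} and returns it: the missing entry is
  // papered over and the bug surfaces later, far from its cause.
  void* p = buffer_mapping[1];
  assert(p == nullptr);

  // .at() refuses to invent an entry and throws immediately.
  try {
    buffer_mapping.at(2);
  } catch (const std::out_of_range&) {
    return 0; // expected: the failed lookup is reported at its source
  }
  return 1;
}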

torch/csrc/jit/tensorexpr/llvm_codegen.cpp (+9, -2)
@@ -2063,8 +2063,15 @@ void LLVMCodeGenImpl::visit(AllocatePtr v) {
 }
 
 void LLVMCodeGenImpl::visit(PlacementAllocatePtr v) {
-  llvm::Value* ptr = varToVal_.at(v->buf_to_reuse()->base_handle());
-  varToVal_[v->buf()->base_handle()] = ptr;
+  auto buf_to_reuse = v->buf_to_reuse();
+  auto buf = v->buf();
+
+  llvm::Value* ptr = varToVal_.at(buf_to_reuse->base_handle());
+  if (buf_to_reuse->dtype().scalar_type() != buf->dtype().scalar_type()) {
+    ptr = irb_.CreatePointerCast(ptr, dtypeToLLVMPtr(buf->dtype()));
+  }
+
+  varToVal_[buf->base_handle()] = ptr;
 }
 
 void LLVMCodeGenImpl::visit(FreePtr v) {
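In the LLVM backend the reused allocation is a typed pointer, so aliasing across dtypes needs an explicit bitcast before later loads and stores see the new element type. A hedged sketch of the shape of the emitted cast; reusePointer and its parameters are illustrative, not the PyTorch API, and dtypeToLLVMPtr in the diff above presumably returns the llvm::PointerType for a tensorexpr Dtype:

#include <llvm/IR/IRBuilder.h>

// Illustrative helper: given the pointer backing the buffer being reused,
// cast it to the new buffer's pointer type when the scalar types differ,
// mirroring what the visitor above does.
llvm::Value* reusePointer(
    llvm::IRBuilder<>& irb,
    llvm::Value* src_ptr, // e.g. the float* backing 'gemm'
    llvm::Type* dst_ptr_ty, // e.g. i8* for the quint8 buffer 'E'
    bool scalar_types_differ) {
  llvm::Value* ptr = src_ptr;
  if (scalar_types_differ) {
    // CreatePointerCast folds to a no-op when the types already match.
    ptr = irb.CreatePointerCast(ptr, dst_ptr_ty);
  }
  return ptr;
}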
