Skip to content

Commit 429a80d

Browse files
EikanWang and pytorchmergebot
authored and committed
[NNC] Lowering function generates the output buffer with the specified stride (pytorch#76529)
Summary: Pass stride information to the lowering function to generate the output buffer with the proper memory layout. Pull Request resolved: pytorch#76529 Reviewed By: ZolotukhinM Differential Revision: D36116712 Pulled By: IvanKobzarev fbshipit-source-id: d3901f756b3710ecce172d6db3ecb0b7c12fb929 (cherry picked from commit b6cd53c)
1 parent 0878ba4 commit 429a80d

30 files changed

+687
-83
lines changed

aten/src/ATen/core/jit_type.h

+28-7
Original file line numberDiff line numberDiff line change
@@ -786,15 +786,36 @@ struct TORCH_API TensorType : public SharedType {
786786

787787
static const TypeKind Kind = TypeKind::TensorType;
788788

789-
static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
790-
std::vector<int64_t> strides(sizes.size());
791-
if (sizes.empty()) // zero-dim case
789+
static std::vector<int64_t> contiguousStridesOf(
790+
at::IntArrayRef sizes,
791+
at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
792+
auto contiguous_fn = [](const at::IntArrayRef& sizes,
793+
const std::vector<int64_t>& dim_order) {
794+
std::vector<int64_t> strides(sizes.size());
795+
if (sizes.empty()) // zero-dim case
796+
return strides;
797+
798+
strides[dim_order[0]] = 1;
799+
for (size_t i = 1; i < dim_order.size(); i++) {
800+
auto cur_dim = dim_order[i];
801+
auto pre_dim = dim_order[i - 1];
802+
strides[cur_dim] = strides[pre_dim] * sizes[pre_dim];
803+
}
792804
return strides;
793-
strides.back() = 1;
794-
for (size_t i = strides.size() - 1; i > 0; i--) {
795-
strides[i - 1] = strides[i] * sizes[i];
805+
};
806+
807+
std::vector<int64_t> dim_order(sizes.size());
808+
if (memory_format == MemoryFormat::ChannelsLast) {
809+
dim_order = {1, 3, 2, 0};
810+
} else if (memory_format == MemoryFormat::ChannelsLast3d) {
811+
dim_order = {1, 4, 3, 2, 0};
812+
} else {
813+
auto ndims = sizes.size();
814+
for (size_t i = 0; i < ndims; i++) {
815+
dim_order[i] = ndims - i - 1; // Reverse
816+
}
796817
}
797-
return strides;
818+
return contiguous_fn(sizes, dim_order);
798819
}
799820

800821
private:

benchmarks/cpp/tensorexpr/bench_reduce.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) {
392392
const int kChunkSize = 8;
393393

394394
te::BufHandle a("A", {M}, te::kFloat);
395-
te::Tensor b =
396-
te::computeSum({a, te::IntList({0}), false}, {}, at::kFloat, at::kCPU);
395+
te::Tensor b = te::computeSum(
396+
{a, te::IntList({0}), false}, {}, {}, at::kFloat, at::kCPU);
397397
te::LoopNest nest({b});
398398

399399
auto loops = nest.getLoopStmtsFor(b);
@@ -456,8 +456,8 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
456456
BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
457457
constexpr int kCacheSize = 1 << 12;
458458
te::BufHandle a("A", {M, N}, te::kFloat);
459-
te::Tensor b =
460-
te::computeSum({a, te::IntList({0}), false}, {N}, at::kFloat, at::kCPU);
459+
te::Tensor b = te::computeSum(
460+
{a, te::IntList({0}), false}, {N}, {1}, at::kFloat, at::kCPU);
461461
te::LoopNest nest({b});
462462

463463
auto sch = state.range(2);
@@ -565,8 +565,8 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand)->Args({1 << 18, 1 << 6});
565565
BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
566566
constexpr int kChunkSize = 8;
567567
te::BufHandle a("A", {M, N}, te::kFloat);
568-
te::Tensor b =
569-
te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
568+
te::Tensor b = te::computeSum(
569+
{a, te::IntList({1}), false}, {M}, {1}, at::kFloat, at::kCPU);
570570
te::LoopNest nest({b});
571571

572572
auto sch = state.range(2);

test/cpp/tensorexpr/test_external_calls.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,7 @@ TEST(ExternalCall, JitCustomFusionOp) {
944944
[external_func_name](
945945
const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
946946
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
947+
const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
947948
const c10::optional<torch::jit::tensorexpr::ScalarType>& output_type,
948949
at::Device device) {
949950
auto output_dtype = Dtype(*output_type);

test/cpp/tensorexpr/test_kernel.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1598,12 +1598,14 @@ TEST_F(Kernel, CodegenInspection) {
15981598
Tensor lowerNanToNum(
15991599
const std::vector<ArgValue>& inputs,
16001600
const std::vector<ExprHandle>& outputShape,
1601+
const std::vector<ExprHandle>& outputStrides,
16011602
const c10::optional<ScalarType>& outputType,
16021603
at::Device device) {
16031604
auto input_buf = c10::get<BufHandle>(inputs[0]);
16041605
auto e = Compute(
16051606
"custom_nan_to_num",
16061607
outputShape,
1608+
outputStrides,
16071609
[&](const std::vector<VarHandle>& axes) {
16081610
std::vector<ExprHandle> indices(axes.begin(), axes.end());
16091611
auto load = input_buf.load(indices);

test/cpp/tensorexpr/test_ops.cpp

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <gtest/gtest.h>
22
#include <torch/csrc/jit/tensorexpr/eval.h>
3+
#include <torch/csrc/jit/tensorexpr/expr.h>
34
#include <torch/csrc/jit/tensorexpr/loopnest.h>
45
#include <torch/csrc/jit/tensorexpr/operators/operators.h>
56
#include <torch/torch.h>
@@ -29,7 +30,10 @@ TEST(Ops, Sum) {
2930
const auto& outShape = outputShapes[idx];
3031

3132
BufHandle a("a", {M, N}, kFloat);
32-
Tensor b = computeSum({a, dims, false}, outShape, c10::kFloat, at::kCPU);
33+
std::vector<ExprHandle> outStrides =
34+
c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
35+
Tensor b = computeSum(
36+
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
3337
auto cg = compile({a}, {b});
3438

3539
auto at = at::arange(M * N, at::kFloat).view({M, N});
@@ -41,3 +45,34 @@ TEST(Ops, Sum) {
4145
ASSERT_TRUE(at::allclose(bt, ref));
4246
}
4347
}
48+
49+
TEST(Ops, ChannelsLastSum) {
50+
constexpr int A = 2;
51+
constexpr int B = 3;
52+
constexpr int C = 4;
53+
constexpr int D = 5;
54+
constexpr int E = 6;
55+
std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
56+
57+
std::vector<std::vector<ExprHandle>> outputShapes = {
58+
{B, C, D, E}, {A, C, D, E}, {C, D, E}};
59+
for (unsigned idx = 0; idx < testDims.size(); idx++) {
60+
const auto& dims = testDims[idx];
61+
const auto& outShape = outputShapes[idx];
62+
63+
BufHandle a("a", {A, B, C, D, E}, kFloat);
64+
std::vector<ExprHandle> outStrides =
65+
c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
66+
Tensor b = computeSum(
67+
{a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
68+
auto cg = compile({a}, {b});
69+
70+
auto at = at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
71+
auto ref = at::sum(at, dims);
72+
auto bt = at::empty_like(ref);
73+
74+
cg->call({at.data_ptr<float>(), bt.data_ptr<float>()});
75+
76+
ASSERT_TRUE(at::allclose(bt, ref));
77+
}
78+
}

test/test_tensorexpr_pybind.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def f(a):
348348
"""
349349
graph = torch._C.parse_ir(graph_str)
350350

351-
def my_custom_lowering(inputs, out_shape, out_type, device):
351+
def my_custom_lowering(inputs, out_shape, out_stride, out_type, device):
352352
def compute(idxs):
353353
load = inputs[0].as_buf().load(idxs)
354354
return te.ifThenElse(

torch/csrc/jit/tensorexpr/expr.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,11 @@ bool Buf::is_contiguous(at::MemoryFormat memory_format) const {
477477
return false;
478478
dim_order = {1, 4, 3, 2, 0};
479479
} else {
480+
if (dims_.empty()) {
481+
// Scalar tensor
482+
TORCH_CHECK(strides_.empty());
483+
return true; // Align with the isContiguous logic in the kernel.cpp
484+
}
480485
for (size_t i = 0; i < ndims; i++) {
481486
dim_order[i] = ndims - i - 1; // Reverse
482487
}

torch/csrc/jit/tensorexpr/expr.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ class TORCH_API Var : public ExprNode<Var> {
186186
std::string name_hint_;
187187
};
188188

189-
std::vector<ExprPtr> make_contiguous_strides(
189+
TORCH_API std::vector<ExprPtr> make_contiguous_strides(
190190
const std::vector<ExprHandle>& dims);
191-
std::vector<ExprPtr> make_channels_last_strides(
191+
TORCH_API std::vector<ExprPtr> make_channels_last_strides(
192192
const std::vector<ExprHandle>& dims);
193193

194194
class TORCH_API Buf : public ExprNode<Buf> {
@@ -324,7 +324,6 @@ class TORCH_API Buf : public ExprNode<Buf> {
324324
bool is_cont_with(int cur_dim, int adjacent_dim) const;
325325
bool is_stride_one(int cur_dim) const;
326326

327-
private:
328327
VarPtr base_handle_;
329328
std::vector<ExprPtr> dims_;
330329
std::vector<ExprPtr> strides_;

torch/csrc/jit/tensorexpr/kernel.cpp

+79-14
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ std::vector<int64_t> _pair_int(IValue v) {
208208
}
209209
}
210210

211-
static bool isContiguous(const torch::jit::Value* v) {
211+
static bool isContiguous(
212+
const torch::jit::Value* v,
213+
at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
212214
auto const& tt = v->type()->cast<TensorType>();
213215
if (!tt) {
214216
return false;
@@ -221,6 +223,14 @@ static bool isContiguous(const torch::jit::Value* v) {
221223
if (!sizes || !strides) {
222224
return false;
223225
}
226+
227+
// Check dimension size first
228+
int ndims = (*sizes).size();
229+
if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
230+
(memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
231+
return false;
232+
}
233+
224234
return *strides == TensorType::contiguousStridesOf(*sizes);
225235
}
226236

@@ -475,8 +485,38 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
475485
hasRandom_ = true;
476486
}
477487

488+
// Check if the tensor is a contiguous tensor
489+
bool is_contiguous = false;
490+
// Check if the tensor is a channels-last contiguous tensor
491+
bool is_channels_last_contiguous = false;
492+
for (auto input : inputs) {
493+
if (input->type()->kind() != TypeKind::TensorType)
494+
continue;
495+
496+
TORCH_CHECK(bufs_.count(input) > 0);
497+
auto buf_ = bufs_.at(input);
498+
499+
auto _is_contiguous = buf_->is_contiguous();
500+
if (_is_contiguous) {
501+
is_contiguous |= _is_contiguous;
502+
} else {
503+
is_channels_last_contiguous |=
504+
(buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
505+
buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
506+
buf_->is_channels_last_1d_contiguous());
507+
}
508+
}
509+
478510
auto outputType = findDtypeForValue(v);
479511
std::vector<ExprHandle> outputShape = sizesForValue(v);
512+
std::vector<ExprHandle> outputStrides;
513+
if (is_channels_last_contiguous && (!is_contiguous)) {
514+
outputStrides =
515+
c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
516+
} else {
517+
// Default
518+
outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
519+
}
480520

481521
std::vector<ArgValue> argInputs;
482522
if (op == prim::ConstantChunk) {
@@ -521,12 +561,14 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
521561
}
522562

523563
if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
524-
return custom_lowering(argInputs, outputShape, outputType, device_);
564+
return custom_lowering(
565+
argInputs, outputShape, outputStrides, outputType, device_);
525566
}
526567
if (v->node()->maybeSchema()) {
527568
if (NNCLoweringFunction lowering =
528569
getStandardLoweringFor(c10::toString(v->node()->schema()))) {
529-
return lowering(argInputs, outputShape, outputType, device_);
570+
return lowering(
571+
argInputs, outputShape, outputStrides, outputType, device_);
530572
}
531573
}
532574
std::string msg = std::string("Unhandled node kind (in computeValue): ") +
@@ -995,28 +1037,53 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
9951037
auto const& outputs = input->owningGraph()->outputs();
9961038
std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
9971039

1040+
auto is_concrete_cont = [](const torch::jit::Value* input) {
1041+
if (input->isCompleteTensor()) {
1042+
return isContiguous(input);
1043+
} else {
1044+
return false;
1045+
}
1046+
};
1047+
1048+
auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc) {
1049+
if (desc.size() == 1) {
1050+
return desc[0] == torch::jit::StrideInput::TENSOR_CONT;
1051+
} else {
1052+
return false;
1053+
}
1054+
};
1055+
9981056
Tensor result(nullptr, nullptr);
9991057
switch (t->kind()) {
10001058
case TypeKind::TensorType: {
10011059
auto tt = input->type()->cast<TensorType>();
1002-
bool contiguous_concrete_tensor =
1003-
(input->isCompleteTensor() && isContiguous(input));
1004-
bool contiguous_sym_tensor = false;
1060+
bool contiguous_concrete_tensor = is_concrete_cont(input);
1061+
bool contiguous_symbolic_tensor = false;
10051062
if (has_symbolic_shapes_) {
10061063
auto desc = getSymbolicInputStrideDesc(input);
1007-
contiguous_sym_tensor =
1008-
desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
1064+
contiguous_symbolic_tensor = is_symbolic_cont(desc);
10091065
}
10101066

1067+
// Get input size and strides
1068+
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
1069+
auto inputTensorStrides = getInputStrides(input, size_handles);
1070+
10111071
// We don't need to copy the input if:
10121072
// 1) it is not an output AND
10131073
// 2) it is contiguous
1014-
bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
1074+
bool contiguous =
1075+
contiguous_concrete_tensor || contiguous_symbolic_tensor;
10151076
if (!outputs_set.count(input) && contiguous) {
10161077
BufHandle inBuffer(
10171078
"t" + input_name_map_[input],
10181079
sizesFromSymbolicShape(tt->symbolic_sizes()),
1080+
inputTensorStrides,
10191081
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
1082+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1083+
inBuffer.node()->is_contiguous() ||
1084+
inBuffer.node()->is_channels_last_1d_contiguous() ||
1085+
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
1086+
inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
10201087
bufs_.emplace(input, inBuffer.node());
10211088
bufferArgs_.emplace_back(inBuffer);
10221089
break;
@@ -1025,8 +1092,6 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
10251092
// if the input isn't contiguous or is an output,
10261093
// write strided input into contiguous buffer that is
10271094
// then used in all further compute
1028-
auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
1029-
auto inputTensorStrides = getInputStrides(input, size_handles);
10301095
ExprHandle flat_size = 1;
10311096
for (size_t i = 0; i < size_handles.size(); ++i) {
10321097
auto size = size_handles[i];
@@ -1168,11 +1233,11 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
11681233
"Ouput tensor has no corresponding bufs in the fuser."));
11691234
BufPtr buf = bufs_.at(v);
11701235
// output is contiguous, no work to do
1171-
if (tensorOutputStrideDesc_[v->offset()] ==
1172-
torch::jit::StrideInput::TENSOR_CONT) {
1236+
auto stride_desc = tensorOutputStrideDesc_[v->offset()];
1237+
if (stride_desc == torch::jit::StrideInput::TENSOR_CONT) {
11731238
return Tensor(buf, nullptr);
1174-
;
11751239
}
1240+
11761241
TORCH_INTERNAL_ASSERT(
11771242
tensorOutputStrideDesc_[v->offset()] ==
11781243
torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);

0 commit comments

Comments
 (0)