@@ -208,7 +208,9 @@ std::vector<int64_t> _pair_int(IValue v) {
   }
 }
 
-static bool isContiguous(const torch::jit::Value* v) {
+static bool isContiguous(
+    const torch::jit::Value* v,
+    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
   auto const& tt = v->type()->cast<TensorType>();
   if (!tt) {
     return false;
@@ -221,6 +223,14 @@ static bool isContiguous(const torch::jit::Value* v) {
   if (!sizes || !strides) {
     return false;
   }
+
+  // Check dimension size first
+  int ndims = (*sizes).size();
+  if ((memory_format == at::MemoryFormat::ChannelsLast && ndims != 4) ||
+      (memory_format == at::MemoryFormat::ChannelsLast3d && ndims != 5)) {
+    return false;
+  }
+
   return *strides == TensorType::contiguousStridesOf(*sizes);
 }
 
@@ -475,8 +485,38 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
     hasRandom_ = true;
   }
 
+  // Check if the tensor is a contiguous tensor
+  bool is_contiguous = false;
+  // Check if the tensor is a channels-last contiguous tensor
+  bool is_channels_last_contiguous = false;
+  for (auto input : inputs) {
+    if (input->type()->kind() != TypeKind::TensorType)
+      continue;
+
+    TORCH_CHECK(bufs_.count(input) > 0);
+    auto buf_ = bufs_.at(input);
+
+    auto _is_contiguous = buf_->is_contiguous();
+    if (_is_contiguous) {
+      is_contiguous |= _is_contiguous;
+    } else {
+      is_channels_last_contiguous |=
+          (buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
+           buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
+           buf_->is_channels_last_1d_contiguous());
+    }
+  }
+
   auto outputType = findDtypeForValue(v);
   std::vector<ExprHandle> outputShape = sizesForValue(v);
+  std::vector<ExprHandle> outputStrides;
+  if (is_channels_last_contiguous && (!is_contiguous)) {
+    outputStrides =
+        c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
+  } else {
+    // Default
+    outputStrides = c10::fmap<ExprHandle>(make_contiguous_strides(outputShape));
+  }
 
   std::vector<ArgValue> argInputs;
   if (op == prim::ConstantChunk) {
@@ -521,12 +561,14 @@ Tensor TensorExprKernel::computeValue(const torch::jit::Value* v) {
   }
 
   if (NNCLoweringFunction custom_lowering = getCustomLoweringFor(op)) {
-    return custom_lowering(argInputs, outputShape, outputType, device_);
+    return custom_lowering(
+        argInputs, outputShape, outputStrides, outputType, device_);
   }
   if (v->node()->maybeSchema()) {
     if (NNCLoweringFunction lowering =
             getStandardLoweringFor(c10::toString(v->node()->schema()))) {
-      return lowering(argInputs, outputShape, outputType, device_);
+      return lowering(
+          argInputs, outputShape, outputStrides, outputType, device_);
     }
   }
   std::string msg = std::string("Unhandled node kind (in computeValue): ") +
@@ -995,28 +1037,53 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
   auto const& outputs = input->owningGraph()->outputs();
   std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
 
+  auto is_concrete_cont = [](const torch::jit::Value* input) {
+    if (input->isCompleteTensor()) {
+      return isContiguous(input);
+    } else {
+      return false;
+    }
+  };
+
+  auto is_symbolic_cont = [](std::vector<torch::jit::StrideInput> desc) {
+    if (desc.size() == 1) {
+      return desc[0] == torch::jit::StrideInput::TENSOR_CONT;
+    } else {
+      return false;
+    }
+  };
+
   Tensor result(nullptr, nullptr);
   switch (t->kind()) {
     case TypeKind::TensorType: {
       auto tt = input->type()->cast<TensorType>();
-      bool contiguous_concrete_tensor =
-          (input->isCompleteTensor() && isContiguous(input));
-      bool contiguous_sym_tensor = false;
+      bool contiguous_concrete_tensor = is_concrete_cont(input);
+      bool contiguous_symbolic_tensor = false;
       if (has_symbolic_shapes_) {
         auto desc = getSymbolicInputStrideDesc(input);
-        contiguous_sym_tensor =
-            desc.size() == 1 && desc[0] == torch::jit::StrideInput::TENSOR_CONT;
+        contiguous_symbolic_tensor = is_symbolic_cont(desc);
       }
 
+      // Get input size and strides
+      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
+      auto inputTensorStrides = getInputStrides(input, size_handles);
+
       // We don't need to copy the input if:
       //  1) it is not an output AND
       //  2) it is contiguous
-      bool contiguous = contiguous_concrete_tensor || contiguous_sym_tensor;
+      bool contiguous =
+          contiguous_concrete_tensor || contiguous_symbolic_tensor;
       if (!outputs_set.count(input) && contiguous) {
        BufHandle inBuffer(
            "t" + input_name_map_[input],
            sizesFromSymbolicShape(tt->symbolic_sizes()),
+            inputTensorStrides,
            ToDtype(static_cast<ScalarType>(*tt->scalarType())));
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+            inBuffer.node()->is_contiguous() ||
+            inBuffer.node()->is_channels_last_1d_contiguous() ||
+            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast) ||
+            inBuffer.node()->is_contiguous(at::MemoryFormat::ChannelsLast3d));
         bufs_.emplace(input, inBuffer.node());
         bufferArgs_.emplace_back(inBuffer);
         break;
@@ -1025,8 +1092,6 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
       // if the input isn't contiguous or is an output,
       // write strided input into  contiguous buffer that is
      // then used in all further compute
-      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
-      auto inputTensorStrides = getInputStrides(input, size_handles);
       ExprHandle flat_size = 1;
       for (size_t i = 0; i < size_handles.size(); ++i) {
         auto size = size_handles[i];
@@ -1168,11 +1233,11 @@ Tensor TensorExprKernel::convertSymbolicOutputToCorrectStrides(
           "Ouput tensor has no corresponding bufs in the fuser."));
   BufPtr buf = bufs_.at(v);
   // output is contiguous, no work to do
-  if (tensorOutputStrideDesc_[v->offset()] ==
-      torch::jit::StrideInput::TENSOR_CONT) {
+  auto stride_desc = tensorOutputStrideDesc_[v->offset()];
+  if (stride_desc == torch::jit::StrideInput::TENSOR_CONT) {
     return Tensor(buf, nullptr);
-    ;
   }
+
   TORCH_INTERNAL_ASSERT(
       tensorOutputStrideDesc_[v->offset()] ==
       torch::jit::StrideInput::TENSOR_CONT_CHANNELS_LAST);