
Commit 320e5a8

malfet authored and pytorchmergebot committed
Revert D34808051: [tensorexpr] Enabled aten::stack in the fuser pass with static shapes

Test Plan: revert-hammer

Differential Revision: D34808051

Original commit changeset: 213e2ffdf87f
Original Phabricator Diff: D34808051

fbshipit-source-id: b618daeb346f784e8ab9525040edcb4a30a39613
(cherry picked from commit e47b973)
1 parent ec6f767 commit 320e5a8

4 files changed: +6 −69 lines

test/cpp/tensorexpr/test_te_fuser_pass.cpp (−19 lines)
@@ -317,25 +317,6 @@ TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
   testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);
 }
 
-TEST(TEFuserPass, FuserPass_Stack) {
-  WithCPUFuser cf;
-  const auto graph_string =
-      R"IR(graph(%y.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu),
-      %x.1 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu)):
-  %1 : int = prim::Constant[value=2]()
-  %9 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%x.1)
-  %7 : Float(5, 3, 3, 6, strides=[54, 18, 6, 1], requires_grad=0, device=cpu) = aten::tanh(%y.1)
-  %5 : Tensor[] = prim::ListConstruct(%9, %7)
-  %z.2 : Float(5, 3, 2, 3, 6, strides=[108, 36, 18, 6, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1)
-  return (%z.2)
-)IR";
-  auto g = std::make_shared<Graph>();
-  torch::jit::parseIR(graph_string, g.get());
-  g->lint();
-  FuseTensorExprs(g, /* min_group_size= */ 2);
-  testing::FileCheck().check("prim::TensorExprGroup")->run(*g);
-}
-
 TEST(TEFuserPass, FuserPass_Where) {
   WithCPUFuser cf;
   const auto graph_string = R"IR(
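
For reference, the patch being reverted handled aten::stack in the same branches as aten::cat throughout the pass (see tensorexpr_fuser.cpp below), relying on the fact that torch.stack is expressible as unsqueeze followed by cat. A minimal Python sketch of that equivalence, using the same shapes as the deleted test (variable names are illustrative):

import torch

# Two (5, 3, 3, 6) inputs stacked along dim=2 give a (5, 3, 2, 3, 6) result,
# matching the output type in the IR above.
x = torch.randn(5, 3, 3, 6)
y = torch.randn(5, 3, 3, 6)

stacked = torch.stack((torch.tanh(x), torch.tanh(y)), dim=2)

# aten::stack(tensors, d) is equivalent to unsqueezing each tensor at d
# and concatenating along d.
via_cat = torch.cat((torch.tanh(x).unsqueeze(2), torch.tanh(y).unsqueeze(2)), dim=2)

assert stacked.shape == (5, 3, 2, 3, 6)
assert torch.equal(stacked, via_cat)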

test/test_jit_fuser_te.py (−17 lines)
@@ -739,22 +739,6 @@ def foo(hx, cx):
         # XXX: TE fuser can handle concats in a fusion group.
         # FileCheck().check("FusedConcat").check_next("return").run(str(graph))
 
-    def test_stack(self):
-        # "aten::stack fusion is not enabled yet with dynamic shapes"
-        if self.dynamic_shapes:
-            return True
-        with set_fusion_group_inlining(True):
-            for device in self.devices:
-                hx = torch.randn(3, 20, dtype=torch.float, device=device)
-                cx = torch.randn(3, 20, dtype=torch.float, device=device)
-
-                def foo(hx, cx):
-                    return torch.stack((hx + cx, hx - cx))
-
-                ge = self.checkTrace(foo, (hx, cx))
-                graph = ge.graph_for(hx, cx)
-                self.assertAllFused(graph)
-
     def test_remove_output_used_only_in_size(self):
         for device in self.devices:
             def test_fuse(a, b):
@@ -1797,7 +1781,6 @@ def apply(fn):
         devices = self.devices
         list_ops = [
            torch.cat,
-           torch.stack
        ]
        for dtype, op, device in product(self.dtypes, list_ops, devices):
            if dtype in [torch.float16, torch.bfloat16] and device == "cpu":
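
After the revert, a function like the deleted test_stack no longer has aten::stack in the fuser's supported set, so the stack itself should stay outside any fusion group. A rough sketch of checking this by hand, assuming the TE-fuser toggles used elsewhere in this test file (torch._C._jit_set_texpr_fuser_enabled, torch._C._jit_override_can_fuse_on_cpu) and a few warm-up runs for the profiling executor:

import torch

torch._C._jit_set_texpr_fuser_enabled(True)
torch._C._jit_override_can_fuse_on_cpu(True)

@torch.jit.script
def foo(hx, cx):
    return torch.stack((hx + cx, hx - cx))

hx = torch.randn(3, 20)
cx = torch.randn(3, 20)
for _ in range(3):  # warm up so the fuser sees profiled shapes
    foo(hx, cx)

graph = foo.graph_for(hx, cx)
# The elementwise add/sub may still fuse, but aten::stack should now appear
# outside any prim::TensorExprGroup in the optimized graph.
print("aten::stack" in str(graph), "prim::TensorExprGroup" in str(graph))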

test/test_tensorexpr_pybind.py (−25 lines)
@@ -390,31 +390,6 @@ def f(a):
         np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
         np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
 
-    @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
-    def test_kernel_with_stack(self):
-        def f(a, b):
-            return torch.stack((a, b), dim=1)
-
-        device = "cpu"
-        x = torch.rand((3, 5), device=device)
-        y = torch.rand((3, 5), device=device)
-        graph_str = """
-graph(%x.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu),
-      %y.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu)):
-  %1 : int = prim::Constant[value=1]()
-  %5 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
-  %z.2 : Float(3, 2, 5, strides=[10, 5, 1], requires_grad=0, device=cpu) = aten::stack(%5, %1) # local/stack.py:39:12
-  return (%z.2)
-"""
-        graph = torch._C.parse_ir(graph_str)
-
-        kernel = te.TensorExprKernel(graph)
-        res1 = kernel.run((x, y))
-        res2 = kernel.fallback((x, y))
-        correct = f(x, y)
-        np.testing.assert_allclose(res1.numpy(), correct.numpy(), atol=2e-3)
-        np.testing.assert_allclose(res2.numpy(), correct.numpy(), atol=2e-3)
-
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
     def test_alloc_in_loop(self):
         a, tmp, b = [
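
Since aten::cat stays in the supported set, the same pybind round trip still works if the deleted test's IR uses cat instead of stack; a minimal sketch under that assumption (LLVM backend required, as in the surrounding tests; the output type becomes Float(3, 10, ...) since cat along dim=1 widens rather than adds a dimension):

import numpy as np
import torch

graph_str = """
graph(%x.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y.1 : Float(3, 5, strides=[5, 1], requires_grad=0, device=cpu)):
  %1 : int = prim::Constant[value=1]()
  %5 : Tensor[] = prim::ListConstruct(%x.1, %y.1)
  %z.2 : Float(3, 10, strides=[10, 1], requires_grad=0, device=cpu) = aten::cat(%5, %1)
  return (%z.2)
"""
graph = torch._C.parse_ir(graph_str)

kernel = torch._C._te.TensorExprKernel(graph)
x = torch.rand(3, 5)
y = torch.rand(3, 5)
res = kernel.run((x, y))
# Compare against the eager-mode result, as the deleted test did.
np.testing.assert_allclose(res.numpy(), torch.cat((x, y), dim=1).numpy(), atol=2e-3)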

torch/csrc/jit/passes/tensorexpr_fuser.cpp (+6 −8 lines)
@@ -94,7 +94,6 @@ bool isSupported(Node* node) {
   };
   static const OperatorSet supported_misc_set{
       "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
-      "aten::stack(Tensor[] tensors, int dim=0) -> Tensor",
       "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
   };
   // clang-format on
@@ -772,7 +771,7 @@ class TensorExprFuser {
 
     std::vector<Node*> nodes_to_merge = {to_merge};
 
-    if (to_merge->kind() == aten::cat || to_merge->kind() == aten::stack) {
+    if (to_merge->kind() == aten::cat) {
       Node* listconstruct = to_merge->input(0)->node();
       nodes_to_merge.push_back(listconstruct);
     }
@@ -1054,6 +1053,7 @@ class TensorExprFuser {
     REQ(isFusableOnDevice(node));
     REQ(operators_not_to_fuse.find(node->kind()) ==
         operators_not_to_fuse.end());
+
     for (Value* input : node->inputs()) {
       if (auto const& tt = input->type()->cast<TensorType>()) {
         auto st = tt->scalarType();
@@ -1066,7 +1066,7 @@ class TensorExprFuser {
         }
       }
     }
-    if (node->kind() == aten::cat || node->kind() == aten::stack) {
+    if (node->kind() == aten::cat) {
       REQ(node->input(0)->node()->kind() == prim::ListConstruct);
       REQ(node->input(0)->uses().size() == 1);
       REQ(node->input(1)->node()->kind() == prim::Constant);
@@ -1120,8 +1120,7 @@ class TensorExprFuser {
     REQ(nInputs <= subgraphArgLimit);
 
     // Device checks
-    if (consumer->kind() != aten::cat && producer->kind() != aten::cat &&
-        consumer->kind() != aten::stack && producer->kind() != aten::stack) {
+    if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
       // aten::cat needs a special handling because it takes a Tensor[] as its
       // input We deal with that in the code below.
       auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
@@ -1155,7 +1154,7 @@ class TensorExprFuser {
       REQ(producer->kind() != prim::Constant);
     }
 
-    if (producer->kind() == aten::cat || producer->kind() == aten::stack) {
+    if (producer->kind() == aten::cat) {
       REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
       REQ(producer->input(0)->uses().size() == 1);
       REQ(producer->input(1)->node()->kind() == prim::Constant);
@@ -1173,8 +1172,7 @@ class TensorExprFuser {
         REQ(isFusableOnDevice(input->node()));
       }
       REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
-    } else if (
-        consumer->kind() == aten::cat || consumer->kind() == aten::stack) {
+    } else if (consumer->kind() == aten::cat) {
       REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
       REQ(consumer->input(0)->uses().size() == 1);
       REQ(consumer->input(1)->node()->kind() == prim::Constant);
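
The REQ guards above only admit an aten::cat whose Tensor[] input comes directly from a prim::ListConstruct with a single use, and whose dim argument is a prim::Constant. A small illustrative sketch of checking those same structural properties from Python on a scripted graph:

import torch

@torch.jit.script
def f(a, b):
    return torch.cat([a, b], dim=0)

# Find the aten::cat node and unpack its two inputs: the Tensor[] and the dim.
cat = next(n for n in f.graph.nodes() if n.kind() == "aten::cat")
tensor_list, dim = list(cat.inputs())

# Mirrors the fuser's checks: Tensor[] produced by prim::ListConstruct with a
# single use, and dim produced by prim::Constant.
assert tensor_list.node().kind() == "prim::ListConstruct"
assert len(tensor_list.uses()) == 1
assert dim.node().kind() == "prim::Constant"
print(f.graph)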
