
Commit abfd293

bdhirsh authored and pytorchmergebot committed
functionalization: fix x.is_contiguous(channels_last) (pytorch#94195)
Pull Request resolved: pytorch#94195
Approved by: https://github.com/ezyang
1 parent aba4fb9 commit abfd293
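
In short: FunctionalTensorWrapper::is_contiguous_custom accepts a memory_format argument but previously dropped it when delegating to the wrapped tensor, so under functionalization every contiguity query was answered as if torch.contiguous_format had been requested. The two queries genuinely differ for channels_last tensors; a minimal eager-mode sketch (plain torch, no functionalization involved):

import torch

# A channels_last NCHW tensor is contiguous in the channels_last sense
# but not in the default (contiguous_format) sense, so the memory_format
# argument cannot be ignored without changing the answer.
x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
print(x.is_contiguous())                                   # False
print(x.is_contiguous(memory_format=torch.channels_last))  # True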

File tree: 2 files changed, +16 -1 lines changed


aten/src/ATen/FunctionalTensorWrapper.cpp (+1, -1)

@@ -343,7 +343,7 @@ int64_t FunctionalTensorWrapper::numel_custom() const {
   return value_.unsafeGetTensorImpl()->numel();
 }
 bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
-  return value_.unsafeGetTensorImpl()->is_contiguous();
+  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
 }
 c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
   return value_.unsafeGetTensorImpl()->sym_sizes();
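
The fix forwards memory_format through to the wrapped TensorImpl instead of calling the zero-argument overload, which always checked default contiguity. A hedged end-to-end sketch of the corrected behavior, assuming torch.func.functionalize is available (it routes tensors through FunctionalTensorWrapper):

import torch
from torch.func import functionalize

def f(x):
    return x.contiguous(memory_format=torch.channels_last)

# The input already has channels_last strides, so with the fix the
# contiguous() call is a no-op under functionalization rather than a copy.
x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
out = functionalize(f)(x)
assert out.is_contiguous(memory_format=torch.channels_last)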

test/test_functionalization.py (+15)

@@ -583,6 +583,21 @@ def forward(self, arg0_1):
     return diagonal_scatter
     """)
 
+    def test_channels_last_contiguous(self):
+        def f(x):
+            return x.contiguous(memory_format=torch.channels_last)
+            tmp = torch.ones(2)
+            y = x.diagonal()
+            y.add_(tmp)
+            return x
+        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
+        self.assert_functionalization(f, x)
+        logs = self.get_logs(f, x).strip()
+        # There should be no clone in the graph
+        self.assertExpectedInline(logs, """\
+def forward(self, arg0_1):
+    return arg0_1""")
+
     def test_split(self):
         def f(x):
             # test: view ops that return multiple tensors (split)
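
Why the expected graph is just "return arg0_1": permuting a contiguous (4, 8, 8, 3) NHWC tensor to NCHW yields exactly the strides torch.channels_last prescribes, so the contiguous(memory_format=torch.channels_last) call should be a no-op. A quick stride check (a sketch, not part of the test):

import torch

# For shape (N, C, H, W) = (4, 3, 8, 8), channels_last strides are
# (C*H*W, 1, W*C, C) = (192, 1, 24, 3), exactly what the permute produces.
x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
print(x.stride())                                          # (192, 1, 24, 3)
print(x.is_contiguous(memory_format=torch.channels_last))  # True
# Before this fix, functionalization misreported this as not contiguous
# and inserted a clone; the expected graph above pins the clone-free result.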
