
Commit fc24d06

zou3519 authored and facebook-github-bot committed
Tensor.contiguous, Tensor.is_contiguous batch rule (pytorch#47621)
Summary:
Pull Request resolved: pytorch#47621

Follow-up to pytorch#47365. is_contiguous on BatchedTensorImpl is implemented as follows:
- Whenever one creates a BatchedTensorImpl, we cache the strides of the per-examples, just like we cache the sizes of the per-examples.
- With the cached strides, we use TensorImpl::refresh_contiguous() to compute whether the tensor is contiguous.
- is_contiguous checks the `is_contiguous_` flag that refresh_contiguous() populates.

Both contiguous and is_contiguous only support torch.contiguous_format. I'm not sure what the semantics should be for other memory formats; they are also rank-dependent (e.g., a channels_last tensor must have 4 dimensions), which makes this a bit tricky.

Test Plan:
- new tests

Reviewed By: Chillee, anjali411

Differential Revision: D24840975

Pulled By: zou3519

fbshipit-source-id: 4d86dbf11e2eec45f3f08300ae3f2d79615bb99d
1 parent 6c815c7 commit fc24d06
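
For orientation, here is a minimal usage sketch (not part of the diff) of the behavior this commit enables: querying per-example contiguity inside vmap, assuming the torch.vmap entry point; the shapes are illustrative.

import torch

# Hedged sketch of what now works under vmap. x inside the mapped function is
# the logical per-example tensor; is_contiguous() reflects its layout.
def report(x):
    return torch.tensor(1.) if x.is_contiguous() else torch.tensor(0.)

B0 = 3
contig = torch.randn(B0, 2, 7)      # per-examples are laid out contiguously
noncontig = torch.randn(2, B0, 7)   # batch dim in the middle -> per-examples are strided

print(torch.vmap(report)(contig))                # ones of shape (B0,)
print(torch.vmap(report, in_dims=1)(noncontig))  # zeros of shape (B0,)

# Other memory formats remain unsupported and raise the NYI error, e.g.:
# torch.vmap(lambda x: x.is_contiguous(memory_format=torch.channels_last))(contig)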

File tree:
- aten/src/ATen/BatchedTensorImpl.cpp
- aten/src/ATen/BatchingRegistrations.cpp
- aten/src/ATen/test/vmap_test.cpp
- test/test_vmap.py

4 files changed, +99 −2 lines changed

Diff for: aten/src/ATen/BatchedTensorImpl.cpp

+11 −1

@@ -19,13 +19,18 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
 
   const auto public_dims = value_.dim() - bdims_.size();
   const auto value_sizes = value_.sizes();
+  const auto value_strides = value_.strides();
   sizes_.clear();
   sizes_.reserve(public_dims);
+  strides_.clear();
+  strides_.reserve(public_dims);
   for (int64_t dim = 0; dim < public_dims; dim++) {
     auto actual_dim = actualDim(dim, /*wrap_dim=*/false);
     sizes_.push_back(value_sizes.at(actual_dim));
+    strides_.push_back(value_strides.at(actual_dim));
   }
   refresh_numel();
+  refresh_contiguous();
 }
 
 int64_t BatchedTensorImpl::actualDim(int64_t dim, bool wrap_dim) const {

@@ -77,9 +82,14 @@ IntArrayRef BatchedTensorImpl::strides() const {
 int64_t BatchedTensorImpl::stride(int64_t d) const {
   TORCH_CHECK(false, "NYI: Getting tensor strides inside of vmap");
 }
+
 bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
-  TORCH_CHECK(false, "NYI: querying is_contiguous inside of vmap");
+  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
+      "NYI: querying is_contiguous inside of vmap for memory_format ",
+      "other than torch.contiguous_format");
+  return is_contiguous_;
 }
+
 const Storage& BatchedTensorImpl::storage() const {
   TORCH_CHECK(false, "Due to limitations, we cannot access the storage() of a tensor from inside of vmap.");
 }
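
To illustrate the mechanism described in the summary: once the per-example strides are cached alongside the per-example sizes, refresh_contiguous() can decide contiguity the same way a regular TensorImpl does. Below is a simplified Python sketch of that row-major check; it is illustrative only, the names are made up, and the real C++ logic in TensorImpl handles further edge cases.

# Simplified sketch of the torch.contiguous_format check that
# refresh_contiguous() effectively performs over cached sizes and strides.
def is_row_major_contiguous(sizes, strides):
    if 0 in sizes:          # tensors with zero elements count as contiguous
        return True
    expected = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if size == 1:       # strides of size-1 dims are irrelevant
            continue
        if stride != expected:
            return False
        expected *= size
    return True

# A (2, 7) per-example slice of a (2, B0, 7) tensor with the batch dim in the
# middle has strides (7 * B0, 1) rather than (7, 1), so it is not contiguous.
B0 = 3
assert is_row_major_contiguous([2, 7], [7, 1])
assert not is_row_major_contiguous([2, 7], [7 * B0, 1])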

Diff for: aten/src/ATen/BatchingRegistrations.cpp

+11

@@ -338,6 +338,15 @@ Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64
   return self_physical.newLogicalFromPhysical(result);
 }
 
+Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
+  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
+      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
+      "than torch.contiguous_format");
+  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
+  auto result = physical_view.tensor().contiguous(memory_format);
+  return physical_view.newLogicalFromPhysical(result);
+}
+
 Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
   auto size_physical = self_physical.getPhysicalShape(size);

@@ -1050,6 +1059,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl_UNBOXED("new_empty", new_empty_batching_rule);
   m.impl_UNBOXED("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
+
+  m.impl("contiguous", contiguous_batching_rule);
 }
 
 } // namespace at
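
As a usage note (an illustrative sketch, not part of the diff): the rule above lowers Tensor.contiguous to a contiguous() call on the underlying physical tensor, so when the per-examples are already contiguous the original tensor comes back unchanged, which is what the new test_contiguous test below asserts. Assuming the torch.vmap entry point:

import torch

# Mirrors test_contiguous below: contiguous() under vmap hands back the
# original tensor object when the per-examples are already contiguous.
B0 = 3
x = torch.randn(B0, 2, 5, 7).movedim(0, 2)   # shape (2, 5, B0, 7); batch dim is index 2
result = torch.vmap(torch.Tensor.contiguous, in_dims=2, out_dims=2)(x)
assert result is x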

Diff for: aten/src/ATen/test/vmap_test.cpp

+1 −1

@@ -15,8 +15,8 @@ TEST(VmapTest, TestBatchedTensor) {
   ASSERT_EQ(x.sizes(), expected_size);
   ASSERT_EQ(x.dim(), 2);
   ASSERT_EQ(x.numel(), 8);
+  ASSERT_EQ(x.is_contiguous(), false);
   ASSERT_THROW(x.strides(), c10::Error);
-  ASSERT_THROW(x.is_contiguous(), c10::Error);
   ASSERT_THROW(x.storage(), c10::Error);
   ASSERT_THROW(x.storage_offset(), c10::Error);
 }

Diff for: test/test_vmap.py

+76
@@ -1265,6 +1265,26 @@ def get(shape):
         result = vmap(op)(real_tensor)
         self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
 
+    def test_contiguous(self):
+        op = Tensor.contiguous
+
+        self._test_unary(op, TensorFactory.randn, 'cpu')
+
+        # check that contiguous returns the original tensor if the per-examples
+        # are already contiguous
+        B0 = 3
+        x = torch.randn(B0, 2, 5, 7)
+        x = x.movedim(0, 2)
+        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
+        self.assertTrue(result is x)
+
+        msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
+        tensor = torch.randn(B0, 3)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
+
     def test_chunk(self):
         test = self._vmap_view_test
         op = torch.chunk

@@ -1432,6 +1452,62 @@ def foo(x):
         self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
         self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
 
+    def test_is_contiguous(self):
+        def foo(x):
+            if x.is_contiguous():
+                return torch.tensor(1.)
+            else:
+                return torch.tensor(0.)
+
+        B0, B1 = 3, 5
+
+        # Single batch dim
+        contig = torch.randn(B0, 2, 7)
+        self.assertEqual(vmap(foo)(contig), torch.ones(B0))
+
+        noncontig = torch.randn(2, B0, 7)
+        self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
+
+        noncontig = torch.randn(2, B0, 7).movedim(1, 0)
+        self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
+
+        noncontig = torch.randn(2, 7, B0)
+        self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
+
+        # Multiple batch dims
+        contig = torch.randn(B0, B1, 3)
+        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
+
+        contig = torch.randn(B1, B0, 3)
+        self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
+
+        contig = torch.randn(B1, B0, 3).movedim(0, 1)
+        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
+
+        noncontig = torch.randn(B0, 3, B1)
+        self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
+
+        # is_contiguous on empty tensor is True
+        def bar(x):
+            assert x.is_contiguous()
+            return x
+
+        vmap(bar)(torch.randn(B0, 0, 3))
+        vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
+        vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2))
+
+        # is_contiguous with other memory formats
+        def baz(x, memory_format):
+            x.is_contiguous(memory_format=memory_format)
+            return x
+
+        msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
+        tensor = torch.randn(B0, 2, 7, 3)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
+        with self.assertRaisesRegex(RuntimeError, msg):
+            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
+
     def test_movedim(self):
         op = torch.movedim
         test = self._vmap_view_test
