|
| 1 | +#include <gtest/gtest.h> |
| 2 | + |
| 3 | +#include <ATen/core/ivalue.h> |
| 4 | + |
| 5 | +#include <c10/util/flat_hash_map.h> |
| 6 | +#include <c10/util/irange.h> |
| 7 | +#include <c10/util/tempfile.h> |
| 8 | + |
| 9 | +#include <torch/torch.h> |
| 10 | + |
| 11 | +#include <test/cpp/api/support.h> |
| 12 | + |
| 13 | +#include <cstdio> |
| 14 | +#include <memory> |
| 15 | +#include <sstream> |
| 16 | +#include <string> |
| 17 | +#include <vector> |
| 18 | + |
| 19 | +using namespace torch::test; |
| 20 | +using namespace torch::nn; |
| 21 | +using namespace torch::optim; |
| 22 | + |
| 23 | +TEST(IValueTest, DeepcopyTensors) { |
| 24 | + torch::Tensor t0 = torch::randn({2, 3}); |
| 25 | + torch::Tensor t1 = torch::randn({3, 4}); |
| 26 | + torch::Tensor t2 = t0.detach(); |
| 27 | + torch::Tensor t3 = t0; |
| 28 | + torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2); |
| 29 | + std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4}; |
| 30 | + c10::List<torch::Tensor> tensor_list(tensor_vector); |
| 31 | + torch::IValue tensor_list_ivalue(tensor_list); |
| 32 | + |
| 33 | + c10::IValue::CompIdentityIValues ivalue_compare; |
| 34 | + |
| 35 | + // Make sure our setup configuration is correct |
| 36 | + ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get())); |
| 37 | + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get())); |
| 38 | + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get())); |
| 39 | + ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get())); |
| 40 | + ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get())); |
| 41 | + |
| 42 | + c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy(); |
| 43 | + c10::List<torch::IValue> copied_list = copied_ivalue.toList(); |
| 44 | + |
| 45 | + // Make sure our setup configuration is correct |
| 46 | + ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get())); |
| 47 | + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get())); |
| 48 | + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get())); |
| 49 | + ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get())); |
| 50 | + // NOTE: this is actually incorrect. Ideally, these _should_ be aliases. |
| 51 | + ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get())); |
| 52 | + |
| 53 | + ASSERT_TRUE(copied_list[0].get().toTensor().allclose( |
| 54 | + tensor_list[0].get().toTensor())); |
| 55 | + ASSERT_TRUE(copied_list[1].get().toTensor().allclose( |
| 56 | + tensor_list[1].get().toTensor())); |
| 57 | + ASSERT_TRUE(copied_list[2].get().toTensor().allclose( |
| 58 | + tensor_list[2].get().toTensor())); |
| 59 | + ASSERT_TRUE(copied_list[3].get().toTensor().allclose( |
| 60 | + tensor_list[3].get().toTensor())); |
| 61 | + ASSERT_TRUE(copied_list[4].get().toTensor().allclose( |
| 62 | + tensor_list[4].get().toTensor())); |
| 63 | +} |
0 commit comments