|
1 | 1 | /*******************************************************************************
|
2 |
| -* Copyright 2020 Intel Corporation |
| 2 | +* Copyright 2021 Intel Corporation |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
14 | 14 | * limitations under the License.
|
15 | 15 | *******************************************************************************/
|
16 | 16 |
|
17 |
| -#pragma once |
18 |
| - |
19 |
| -#ifdef ONEDAL_DATA_PARALLEL |
20 |
| -#include <daal/include/data_management/data/internal/numeric_table_sycl_homogen.h> |
21 |
| -#endif |
22 |
| - |
23 |
| -#include "oneapi/dal/backend/memory.hpp" |
24 |
| -#include "oneapi/dal/table/detail/table_builder.hpp" |
25 |
| -#include "oneapi/dal/table/backend/interop/sycl_table_adapter.hpp" |
26 |
| -#include "oneapi/dal/table/backend/interop/host_homogen_table_adapter.hpp" |
27 |
| -#include "oneapi/dal/table/backend/interop/host_soa_table_adapter.hpp" |
28 |
| -#include "oneapi/dal/table/backend/interop/host_csr_table_adapter.hpp" |
29 |
| -#include "oneapi/dal/backend/interop/csr_block_owner.hpp" |
30 |
| - |
31 |
| -namespace oneapi::dal::backend::interop { |
32 |
| - |
33 |
| -template <typename Data> |
34 |
| -inline auto allocate_daal_homogen_table(std::int64_t row_count, std::int64_t column_count) { |
35 |
| - return daal::data_management::HomogenNumericTable<Data>::create( |
36 |
| - dal::detail::integral_cast<std::size_t>(column_count), |
37 |
| - dal::detail::integral_cast<std::size_t>(row_count), |
38 |
| - daal::data_management::NumericTable::doAllocate); |
39 |
| -} |
40 |
| - |
41 |
| -template <typename Data> |
42 |
| -inline auto empty_daal_homogen_table(std::int64_t column_count) { |
43 |
| - return daal::data_management::HomogenNumericTable<Data>::create( |
44 |
| - dal::detail::integral_cast<std::size_t>(column_count), |
45 |
| - dal::detail::integral_cast<std::size_t>(0), |
46 |
| - daal::data_management::NumericTable::notAllocate); |
47 |
| -} |
48 |
| - |
49 |
| -template <typename Data> |
50 |
| -inline auto convert_to_daal_homogen_table(array<Data>& data, |
51 |
| - std::int64_t row_count, |
52 |
| - std::int64_t column_count, |
53 |
| - bool allow_copy = false) { |
54 |
| - if (!data.get_count()) { |
55 |
| - return daal::services::SharedPtr<daal::data_management::HomogenNumericTable<Data>>(); |
56 |
| - } |
57 |
| - |
58 |
| - if (allow_copy) { |
59 |
| - data.need_mutable_data(); |
60 |
| - } |
61 |
| - |
62 |
| - ONEDAL_ASSERT(data.has_mutable_data()); |
63 |
| - const auto daal_data = |
64 |
| - daal::services::SharedPtr<Data>(data.get_mutable_data(), daal_object_owner{ data }); |
65 |
| - |
66 |
| - return daal::data_management::HomogenNumericTable<Data>::create( |
67 |
| - daal_data, |
68 |
| - dal::detail::integral_cast<std::size_t>(column_count), |
69 |
| - dal::detail::integral_cast<std::size_t>(row_count)); |
70 |
| -} |
71 |
| - |
72 |
| -template <typename Data> |
73 |
| -inline daal::data_management::NumericTablePtr copy_to_daal_homogen_table(const table& table) { |
74 |
| - // TODO: Preserve information about features |
75 |
| - const bool allow_copy = true; |
76 |
| - auto rows = row_accessor<const Data>{ table }.pull(); |
77 |
| - return convert_to_daal_homogen_table(rows, |
78 |
| - table.get_row_count(), |
79 |
| - table.get_column_count(), |
80 |
| - allow_copy); |
81 |
| -} |
82 |
| - |
83 |
| -template <typename Data> |
84 |
| -inline table convert_from_daal_homogen_table(const daal::data_management::NumericTablePtr& nt) { |
85 |
| - if (nt->getNumberOfRows() == 0) { |
86 |
| - return table{}; |
87 |
| - } |
88 |
| - daal::data_management::BlockDescriptor<Data> block; |
89 |
| - const std::int64_t row_count = dal::detail::integral_cast<std::int64_t>(nt->getNumberOfRows()); |
90 |
| - const std::int64_t column_count = |
91 |
| - dal::detail::integral_cast<std::int64_t>(nt->getNumberOfColumns()); |
92 |
| - |
93 |
| - nt->getBlockOfRows(0, row_count, daal::data_management::readOnly, block); |
94 |
| - Data* data = block.getBlockPtr(); |
95 |
| - array<Data> arr(data, row_count * column_count, [nt, block](Data* p) mutable { |
96 |
| - nt->releaseBlockOfRows(block); |
97 |
| - }); |
98 |
| - return detail::homogen_table_builder{}.reset(arr, row_count, column_count).build(); |
99 |
| -} |
100 |
| - |
101 |
| -inline daal::data_management::NumericTablePtr wrap_by_host_homogen_adapter( |
102 |
| - const homogen_table& table) { |
103 |
| - const auto& dtype = table.get_metadata().get_data_type(0); |
104 |
| - |
105 |
| - switch (dtype) { |
106 |
| - case data_type::float32: return host_homogen_table_adapter<float>::create(table); |
107 |
| - case data_type::float64: return host_homogen_table_adapter<double>::create(table); |
108 |
| - case data_type::int32: return host_homogen_table_adapter<std::int32_t>::create(table); |
109 |
| - default: return daal::data_management::NumericTablePtr(); |
110 |
| - } |
111 |
| -} |
112 |
| - |
113 |
| -inline daal::data_management::NumericTablePtr wrap_by_host_soa_adapter(const homogen_table& table) { |
114 |
| - const auto& dtype = table.get_metadata().get_data_type(0); |
115 |
| - |
116 |
| - switch (dtype) { |
117 |
| - case data_type::float32: return host_soa_table_adapter::create<float>(table); |
118 |
| - case data_type::float64: return host_soa_table_adapter::create<double>(table); |
119 |
| - case data_type::int32: return host_soa_table_adapter::create<std::int32_t>(table); |
120 |
| - default: return daal::data_management::NumericTablePtr(); |
121 |
| - } |
122 |
| -} |
123 |
| - |
124 |
| -template <typename Data> |
125 |
| -inline daal::data_management::NumericTablePtr convert_to_daal_table(const homogen_table& table) { |
126 |
| - if (table.get_data_layout() == data_layout::row_major) { |
127 |
| - if (auto wrapper = wrap_by_host_homogen_adapter(table)) { |
128 |
| - return wrapper; |
129 |
| - } |
130 |
| - } |
131 |
| - else if (table.get_data_layout() == data_layout::column_major) { |
132 |
| - if (auto wrapper = wrap_by_host_soa_adapter(table)) { |
133 |
| - return wrapper; |
134 |
| - } |
135 |
| - } |
136 |
| - return copy_to_daal_homogen_table<Data>(table); |
137 |
| -} |
138 |
| - |
139 |
| -template <typename T> |
140 |
| -inline auto convert_to_daal_csr_table(array<T>& data, |
141 |
| - array<std::int64_t>& column_indices, |
142 |
| - array<std::int64_t>& row_indices, |
143 |
| - std::int64_t row_count, |
144 |
| - std::int64_t column_count, |
145 |
| - bool allow_copy = false) { |
146 |
| - ONEDAL_ASSERT(data.get_count() == column_indices.get_count()); |
147 |
| - ONEDAL_ASSERT(row_indices.get_count() == row_count + 1); |
148 |
| - |
149 |
| - if (!data.get_count() || !column_indices.get_count() || !row_indices.get_count()) { |
150 |
| - return daal::services::SharedPtr<daal::data_management::CSRNumericTable>(); |
151 |
| - } |
152 |
| - |
153 |
| - if (allow_copy) { |
154 |
| - data.need_mutable_data(); |
155 |
| - column_indices.need_mutable_data(); |
156 |
| - row_indices.need_mutable_data(); |
157 |
| - } |
158 |
| - |
159 |
| - ONEDAL_ASSERT(data.has_mutable_data()); |
160 |
| - ONEDAL_ASSERT(column_indices.has_mutable_data()); |
161 |
| - ONEDAL_ASSERT(row_indices.has_mutable_data()); |
162 |
| - |
163 |
| - const auto daal_data = |
164 |
| - daal::services::SharedPtr<T>(data.get_mutable_data(), daal_object_owner{ data }); |
165 |
| - ONEDAL_ASSERT(sizeof(std::size_t) == sizeof(std::int64_t)); |
166 |
| - const auto daal_column_indices = daal::services::SharedPtr<std::size_t>( |
167 |
| - reinterpret_cast<std::size_t*>(column_indices.get_mutable_data()), |
168 |
| - daal_object_owner{ column_indices }); |
169 |
| - const auto daal_row_indices = daal::services::SharedPtr<std::size_t>( |
170 |
| - reinterpret_cast<std::size_t*>(row_indices.get_mutable_data()), |
171 |
| - daal_object_owner{ row_indices }); |
172 |
| - |
173 |
| - return daal::data_management::CSRNumericTable::create( |
174 |
| - daal_data, |
175 |
| - daal_column_indices, |
176 |
| - daal_row_indices, |
177 |
| - dal::detail::integral_cast<std::size_t>(column_count), |
178 |
| - dal::detail::integral_cast<std::size_t>(row_count)); |
179 |
| -} |
180 |
| - |
181 |
| -template <typename Float> |
182 |
| -inline daal::data_management::CSRNumericTablePtr copy_to_daal_csr_table(const csr_table& table) { |
183 |
| - const bool allow_copy = true; |
184 |
| - auto [data, column_indices, row_offsets] = csr_accessor<const Float>{ table }.pull(); |
185 |
| - return convert_to_daal_csr_table(data, |
186 |
| - column_indices, |
187 |
| - row_offsets, |
188 |
| - table.get_row_count(), |
189 |
| - table.get_column_count(), |
190 |
| - allow_copy); |
191 |
| -} |
192 |
| - |
193 |
| -template <typename T> |
194 |
| -inline table convert_from_daal_csr_table(const daal::data_management::NumericTablePtr& nt) { |
195 |
| - auto block_owner = std::make_shared<csr_block_owner<T>>(csr_block_owner<T>{ nt }); |
196 |
| - |
197 |
| - ONEDAL_ASSERT(sizeof(std::size_t) == sizeof(std::int64_t)); |
198 |
| - |
199 |
| - return csr_table{ |
200 |
| - array<T>{ block_owner->get_data(), |
201 |
| - block_owner->get_element_count(), |
202 |
| - [block_owner](const T* p) {} }, |
203 |
| - array<std::int64_t>{ reinterpret_cast<std::int64_t*>(block_owner->get_column_indices()), |
204 |
| - block_owner->get_element_count(), |
205 |
| - [block_owner](const std::int64_t* p) {} }, |
206 |
| - array<std::int64_t>{ reinterpret_cast<std::int64_t*>(block_owner->get_row_indices()), |
207 |
| - block_owner->get_row_count() + 1, |
208 |
| - [block_owner](const std::int64_t* p) {} }, |
209 |
| - block_owner->get_column_count() |
210 |
| - }; |
211 |
| -} |
212 |
| - |
213 |
| -inline daal::data_management::CSRNumericTablePtr wrap_by_host_csr_adapter(const csr_table& table) { |
214 |
| - const auto& dtype = table.get_metadata().get_data_type(0); |
215 |
| - |
216 |
| - switch (dtype) { |
217 |
| - case data_type::float32: return host_csr_table_adapter<float>::create(table); |
218 |
| - case data_type::float64: return host_csr_table_adapter<double>::create(table); |
219 |
| - case data_type::int32: return host_csr_table_adapter<std::int32_t>::create(table); |
220 |
| - default: return daal::data_management::CSRNumericTablePtr(); |
221 |
| - } |
222 |
| -} |
223 |
| - |
224 |
| -template <typename Float> |
225 |
| -inline daal::data_management::CSRNumericTablePtr convert_to_daal_table(const csr_table& table) { |
226 |
| - auto wrapper = wrap_by_host_csr_adapter(table); |
227 |
| - if (!wrapper) { |
228 |
| - return copy_to_daal_csr_table<Float>(table); |
229 |
| - } |
230 |
| - else { |
231 |
| - return wrapper; |
232 |
| - } |
233 |
| -} |
234 |
| - |
235 |
| -template <typename Data> |
236 |
| -inline daal::data_management::NumericTablePtr convert_to_daal_table(const table& table) { |
237 |
| - if (table.get_kind() == homogen_table::kind()) { |
238 |
| - const auto& homogen = static_cast<const homogen_table&>(table); |
239 |
| - return convert_to_daal_table<Data>(homogen); |
240 |
| - } |
241 |
| - else if (table.get_kind() == csr_table::kind()) { |
242 |
| - const auto& csr = static_cast<const csr_table&>(table); |
243 |
| - return convert_to_daal_table<Data>(csr); |
244 |
| - } |
245 |
| - else { |
246 |
| - return copy_to_daal_homogen_table<Data>(table); |
247 |
| - } |
248 |
| -} |
249 |
| - |
250 |
| -template <typename Data> |
251 |
| -inline table convert_from_daal_table(const daal::data_management::NumericTablePtr& nt) { |
252 |
| - if (nt->getDataLayout() == daal::data_management::NumericTableIface::StorageLayout::csrArray) { |
253 |
| - return convert_from_daal_csr_table<Data>(nt); |
254 |
| - } |
255 |
| - else { |
256 |
| - return convert_from_daal_homogen_table<Data>(nt); |
257 |
| - } |
258 |
| -} |
259 |
| - |
260 |
| -#ifdef ONEDAL_DATA_PARALLEL |
261 |
| -inline daal::data_management::NumericTablePtr convert_to_daal_table(const sycl::queue& queue, |
262 |
| - const table& table) { |
263 |
| - if (!table.has_data()) { |
264 |
| - return daal::data_management::NumericTablePtr{}; |
265 |
| - } |
266 |
| - return interop::sycl_table_adapter::create(queue, table); |
267 |
| -} |
268 |
| - |
269 |
| -template <typename Data> |
270 |
| -inline daal::data_management::NumericTablePtr convert_to_daal_table(const sycl::queue& queue, |
271 |
| - const array<Data>& data, |
272 |
| - std::int64_t row_count, |
273 |
| - std::int64_t column_count) { |
274 |
| - using daal::services::Status; |
275 |
| - using daal::services::SharedPtr; |
276 |
| - using daal::services::internal::Buffer; |
277 |
| - using daal::data_management::internal::SyclHomogenNumericTable; |
278 |
| - using dal::detail::integral_cast; |
279 |
| - |
280 |
| - ONEDAL_ASSERT(data.get_count() == row_count * column_count); |
281 |
| - ONEDAL_ASSERT(data.has_mutable_data()); |
282 |
| - ONEDAL_ASSERT(is_same_context(queue, data)); |
283 |
| - |
284 |
| - const SharedPtr<Data> data_shared{ data.get_mutable_data(), daal_object_owner{ data } }; |
285 |
| - |
286 |
| - Status status; |
287 |
| - const Buffer<Data> buffer{ data_shared, |
288 |
| - integral_cast<std::size_t>(data.get_count()), |
289 |
| - queue, |
290 |
| - status }; |
291 |
| - status_to_exception(status); |
292 |
| - |
293 |
| - const auto table = |
294 |
| - SyclHomogenNumericTable<Data>::create(buffer, |
295 |
| - integral_cast<std::size_t>(column_count), |
296 |
| - integral_cast<std::size_t>(row_count), |
297 |
| - &status); |
298 |
| - status_to_exception(status); |
299 |
| - |
300 |
| - return table; |
301 |
| -} |
302 |
| -#endif |
303 |
| - |
304 |
| -} // namespace oneapi::dal::backend::interop |
| 17 | +#include "oneapi/dal/table/backend/interop/table_conversion.hpp" |
0 commit comments