Skip to content

Commit 6d4ddba

Browse files
author
Nikita Kulikov
committed
One commit to rule them all
1 parent 64bb0fb commit 6d4ddba

25 files changed

+1789
-638
lines changed

cpp/oneapi/dal/backend/interop/data_conversion.hpp

+31-19
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace oneapi::dal::backend::interop {
2424
// TODO: Remove using namespace
2525
using namespace daal::data_management;
2626

27-
features::IndexNumType getIndexNumType(data_type t) {
27+
inline features::IndexNumType getIndexNumType(data_type t) {
2828
switch (t) {
2929
case data_type::int32: return features::DAAL_INT32_S;
3030
case data_type::int64: return features::DAAL_INT64_S;
@@ -36,7 +36,19 @@ features::IndexNumType getIndexNumType(data_type t) {
3636
}
3737
}
3838

39-
internal::ConversionDataType getConversionDataType(data_type t) {
39+
inline data_type get_dal_data_type(features::IndexNumType t) {
40+
switch (t) {
41+
case features::DAAL_INT32_S: return data_type::int32;
42+
case features::DAAL_INT64_S: return data_type::int64;
43+
case features::DAAL_INT32_U: return data_type::uint32;
44+
case features::DAAL_INT64_U: return data_type::uint64;
45+
case features::DAAL_FLOAT32: return data_type::float32;
46+
case features::DAAL_FLOAT64: return data_type::float64;
47+
default: return data_type::float32;
48+
}
49+
}
50+
51+
inline internal::ConversionDataType getConversionDataType(data_type t) {
4052
switch (t) {
4153
case data_type::int32: return internal::DAAL_INT32;
4254
case data_type::float32: return internal::DAAL_SINGLE;
@@ -46,11 +58,11 @@ internal::ConversionDataType getConversionDataType(data_type t) {
4658
}
4759

4860
template <typename DownCast, typename UpCast, typename... Args>
49-
void daal_convert_dispatcher(data_type src_type,
50-
data_type dst_type,
51-
DownCast&& dcast,
52-
UpCast&& ucast,
53-
Args&&... args) {
61+
inline void daal_convert_dispatcher(data_type src_type,
62+
data_type dst_type,
63+
DownCast&& dcast,
64+
UpCast&& ucast,
65+
Args&&... args) {
5466
auto from_type = getIndexNumType(src_type);
5567
auto to_type = getConversionDataType(dst_type);
5668

@@ -74,11 +86,11 @@ void daal_convert_dispatcher(data_type src_type,
7486
}
7587
}
7688

77-
void daal_convert(const void* src,
78-
void* dst,
79-
data_type src_type,
80-
data_type dst_type,
81-
std::int64_t element_count) {
89+
inline void daal_convert(const void* src,
90+
void* dst,
91+
data_type src_type,
92+
data_type dst_type,
93+
std::int64_t element_count) {
8294
daal_convert_dispatcher(src_type,
8395
dst_type,
8496
internal::getVectorDownCast,
@@ -88,13 +100,13 @@ void daal_convert(const void* src,
88100
dst);
89101
}
90102

91-
void daal_convert(const void* src,
92-
void* dst,
93-
data_type src_type,
94-
data_type dst_type,
95-
std::int64_t src_stride,
96-
std::int64_t dst_stride,
97-
std::int64_t element_count) {
103+
inline void daal_convert(const void* src,
104+
void* dst,
105+
data_type src_type,
106+
data_type dst_type,
107+
std::int64_t src_stride,
108+
std::int64_t dst_stride,
109+
std::int64_t element_count) {
98110
daal_convert_dispatcher(src_type,
99111
dst_type,
100112
internal::getVectorStrideDownCast,
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020 Intel Corporation
2+
* Copyright 2021 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,291 +14,4 @@
1414
* limitations under the License.
1515
*******************************************************************************/
1616

17-
#pragma once
18-
19-
#ifdef ONEDAL_DATA_PARALLEL
20-
#include <daal/include/data_management/data/internal/numeric_table_sycl_homogen.h>
21-
#endif
22-
23-
#include "oneapi/dal/backend/memory.hpp"
24-
#include "oneapi/dal/table/detail/table_builder.hpp"
25-
#include "oneapi/dal/table/backend/interop/sycl_table_adapter.hpp"
26-
#include "oneapi/dal/table/backend/interop/host_homogen_table_adapter.hpp"
27-
#include "oneapi/dal/table/backend/interop/host_soa_table_adapter.hpp"
28-
#include "oneapi/dal/table/backend/interop/host_csr_table_adapter.hpp"
29-
#include "oneapi/dal/backend/interop/csr_block_owner.hpp"
30-
31-
namespace oneapi::dal::backend::interop {
32-
33-
template <typename Data>
34-
inline auto allocate_daal_homogen_table(std::int64_t row_count, std::int64_t column_count) {
35-
return daal::data_management::HomogenNumericTable<Data>::create(
36-
dal::detail::integral_cast<std::size_t>(column_count),
37-
dal::detail::integral_cast<std::size_t>(row_count),
38-
daal::data_management::NumericTable::doAllocate);
39-
}
40-
41-
template <typename Data>
42-
inline auto empty_daal_homogen_table(std::int64_t column_count) {
43-
return daal::data_management::HomogenNumericTable<Data>::create(
44-
dal::detail::integral_cast<std::size_t>(column_count),
45-
dal::detail::integral_cast<std::size_t>(0),
46-
daal::data_management::NumericTable::notAllocate);
47-
}
48-
49-
template <typename Data>
50-
inline auto convert_to_daal_homogen_table(array<Data>& data,
51-
std::int64_t row_count,
52-
std::int64_t column_count,
53-
bool allow_copy = false) {
54-
if (!data.get_count()) {
55-
return daal::services::SharedPtr<daal::data_management::HomogenNumericTable<Data>>();
56-
}
57-
58-
if (allow_copy) {
59-
data.need_mutable_data();
60-
}
61-
62-
ONEDAL_ASSERT(data.has_mutable_data());
63-
const auto daal_data =
64-
daal::services::SharedPtr<Data>(data.get_mutable_data(), daal_object_owner{ data });
65-
66-
return daal::data_management::HomogenNumericTable<Data>::create(
67-
daal_data,
68-
dal::detail::integral_cast<std::size_t>(column_count),
69-
dal::detail::integral_cast<std::size_t>(row_count));
70-
}
71-
72-
template <typename Data>
73-
inline daal::data_management::NumericTablePtr copy_to_daal_homogen_table(const table& table) {
74-
// TODO: Preserve information about features
75-
const bool allow_copy = true;
76-
auto rows = row_accessor<const Data>{ table }.pull();
77-
return convert_to_daal_homogen_table(rows,
78-
table.get_row_count(),
79-
table.get_column_count(),
80-
allow_copy);
81-
}
82-
83-
template <typename Data>
84-
inline table convert_from_daal_homogen_table(const daal::data_management::NumericTablePtr& nt) {
85-
if (nt->getNumberOfRows() == 0) {
86-
return table{};
87-
}
88-
daal::data_management::BlockDescriptor<Data> block;
89-
const std::int64_t row_count = dal::detail::integral_cast<std::int64_t>(nt->getNumberOfRows());
90-
const std::int64_t column_count =
91-
dal::detail::integral_cast<std::int64_t>(nt->getNumberOfColumns());
92-
93-
nt->getBlockOfRows(0, row_count, daal::data_management::readOnly, block);
94-
Data* data = block.getBlockPtr();
95-
array<Data> arr(data, row_count * column_count, [nt, block](Data* p) mutable {
96-
nt->releaseBlockOfRows(block);
97-
});
98-
return detail::homogen_table_builder{}.reset(arr, row_count, column_count).build();
99-
}
100-
101-
inline daal::data_management::NumericTablePtr wrap_by_host_homogen_adapter(
102-
const homogen_table& table) {
103-
const auto& dtype = table.get_metadata().get_data_type(0);
104-
105-
switch (dtype) {
106-
case data_type::float32: return host_homogen_table_adapter<float>::create(table);
107-
case data_type::float64: return host_homogen_table_adapter<double>::create(table);
108-
case data_type::int32: return host_homogen_table_adapter<std::int32_t>::create(table);
109-
default: return daal::data_management::NumericTablePtr();
110-
}
111-
}
112-
113-
inline daal::data_management::NumericTablePtr wrap_by_host_soa_adapter(const homogen_table& table) {
114-
const auto& dtype = table.get_metadata().get_data_type(0);
115-
116-
switch (dtype) {
117-
case data_type::float32: return host_soa_table_adapter::create<float>(table);
118-
case data_type::float64: return host_soa_table_adapter::create<double>(table);
119-
case data_type::int32: return host_soa_table_adapter::create<std::int32_t>(table);
120-
default: return daal::data_management::NumericTablePtr();
121-
}
122-
}
123-
124-
template <typename Data>
125-
inline daal::data_management::NumericTablePtr convert_to_daal_table(const homogen_table& table) {
126-
if (table.get_data_layout() == data_layout::row_major) {
127-
if (auto wrapper = wrap_by_host_homogen_adapter(table)) {
128-
return wrapper;
129-
}
130-
}
131-
else if (table.get_data_layout() == data_layout::column_major) {
132-
if (auto wrapper = wrap_by_host_soa_adapter(table)) {
133-
return wrapper;
134-
}
135-
}
136-
return copy_to_daal_homogen_table<Data>(table);
137-
}
138-
139-
template <typename T>
140-
inline auto convert_to_daal_csr_table(array<T>& data,
141-
array<std::int64_t>& column_indices,
142-
array<std::int64_t>& row_indices,
143-
std::int64_t row_count,
144-
std::int64_t column_count,
145-
bool allow_copy = false) {
146-
ONEDAL_ASSERT(data.get_count() == column_indices.get_count());
147-
ONEDAL_ASSERT(row_indices.get_count() == row_count + 1);
148-
149-
if (!data.get_count() || !column_indices.get_count() || !row_indices.get_count()) {
150-
return daal::services::SharedPtr<daal::data_management::CSRNumericTable>();
151-
}
152-
153-
if (allow_copy) {
154-
data.need_mutable_data();
155-
column_indices.need_mutable_data();
156-
row_indices.need_mutable_data();
157-
}
158-
159-
ONEDAL_ASSERT(data.has_mutable_data());
160-
ONEDAL_ASSERT(column_indices.has_mutable_data());
161-
ONEDAL_ASSERT(row_indices.has_mutable_data());
162-
163-
const auto daal_data =
164-
daal::services::SharedPtr<T>(data.get_mutable_data(), daal_object_owner{ data });
165-
ONEDAL_ASSERT(sizeof(std::size_t) == sizeof(std::int64_t));
166-
const auto daal_column_indices = daal::services::SharedPtr<std::size_t>(
167-
reinterpret_cast<std::size_t*>(column_indices.get_mutable_data()),
168-
daal_object_owner{ column_indices });
169-
const auto daal_row_indices = daal::services::SharedPtr<std::size_t>(
170-
reinterpret_cast<std::size_t*>(row_indices.get_mutable_data()),
171-
daal_object_owner{ row_indices });
172-
173-
return daal::data_management::CSRNumericTable::create(
174-
daal_data,
175-
daal_column_indices,
176-
daal_row_indices,
177-
dal::detail::integral_cast<std::size_t>(column_count),
178-
dal::detail::integral_cast<std::size_t>(row_count));
179-
}
180-
181-
template <typename Float>
182-
inline daal::data_management::CSRNumericTablePtr copy_to_daal_csr_table(const csr_table& table) {
183-
const bool allow_copy = true;
184-
auto [data, column_indices, row_offsets] = csr_accessor<const Float>{ table }.pull();
185-
return convert_to_daal_csr_table(data,
186-
column_indices,
187-
row_offsets,
188-
table.get_row_count(),
189-
table.get_column_count(),
190-
allow_copy);
191-
}
192-
193-
template <typename T>
194-
inline table convert_from_daal_csr_table(const daal::data_management::NumericTablePtr& nt) {
195-
auto block_owner = std::make_shared<csr_block_owner<T>>(csr_block_owner<T>{ nt });
196-
197-
ONEDAL_ASSERT(sizeof(std::size_t) == sizeof(std::int64_t));
198-
199-
return csr_table{
200-
array<T>{ block_owner->get_data(),
201-
block_owner->get_element_count(),
202-
[block_owner](const T* p) {} },
203-
array<std::int64_t>{ reinterpret_cast<std::int64_t*>(block_owner->get_column_indices()),
204-
block_owner->get_element_count(),
205-
[block_owner](const std::int64_t* p) {} },
206-
array<std::int64_t>{ reinterpret_cast<std::int64_t*>(block_owner->get_row_indices()),
207-
block_owner->get_row_count() + 1,
208-
[block_owner](const std::int64_t* p) {} },
209-
block_owner->get_column_count()
210-
};
211-
}
212-
213-
inline daal::data_management::CSRNumericTablePtr wrap_by_host_csr_adapter(const csr_table& table) {
214-
const auto& dtype = table.get_metadata().get_data_type(0);
215-
216-
switch (dtype) {
217-
case data_type::float32: return host_csr_table_adapter<float>::create(table);
218-
case data_type::float64: return host_csr_table_adapter<double>::create(table);
219-
case data_type::int32: return host_csr_table_adapter<std::int32_t>::create(table);
220-
default: return daal::data_management::CSRNumericTablePtr();
221-
}
222-
}
223-
224-
template <typename Float>
225-
inline daal::data_management::CSRNumericTablePtr convert_to_daal_table(const csr_table& table) {
226-
auto wrapper = wrap_by_host_csr_adapter(table);
227-
if (!wrapper) {
228-
return copy_to_daal_csr_table<Float>(table);
229-
}
230-
else {
231-
return wrapper;
232-
}
233-
}
234-
235-
template <typename Data>
236-
inline daal::data_management::NumericTablePtr convert_to_daal_table(const table& table) {
237-
if (table.get_kind() == homogen_table::kind()) {
238-
const auto& homogen = static_cast<const homogen_table&>(table);
239-
return convert_to_daal_table<Data>(homogen);
240-
}
241-
else if (table.get_kind() == csr_table::kind()) {
242-
const auto& csr = static_cast<const csr_table&>(table);
243-
return convert_to_daal_table<Data>(csr);
244-
}
245-
else {
246-
return copy_to_daal_homogen_table<Data>(table);
247-
}
248-
}
249-
250-
template <typename Data>
251-
inline table convert_from_daal_table(const daal::data_management::NumericTablePtr& nt) {
252-
if (nt->getDataLayout() == daal::data_management::NumericTableIface::StorageLayout::csrArray) {
253-
return convert_from_daal_csr_table<Data>(nt);
254-
}
255-
else {
256-
return convert_from_daal_homogen_table<Data>(nt);
257-
}
258-
}
259-
260-
#ifdef ONEDAL_DATA_PARALLEL
261-
inline daal::data_management::NumericTablePtr convert_to_daal_table(const sycl::queue& queue,
262-
const table& table) {
263-
if (!table.has_data()) {
264-
return daal::data_management::NumericTablePtr{};
265-
}
266-
return interop::sycl_table_adapter::create(queue, table);
267-
}
268-
269-
template <typename Data>
270-
inline daal::data_management::NumericTablePtr convert_to_daal_table(const sycl::queue& queue,
271-
const array<Data>& data,
272-
std::int64_t row_count,
273-
std::int64_t column_count) {
274-
using daal::services::Status;
275-
using daal::services::SharedPtr;
276-
using daal::services::internal::Buffer;
277-
using daal::data_management::internal::SyclHomogenNumericTable;
278-
using dal::detail::integral_cast;
279-
280-
ONEDAL_ASSERT(data.get_count() == row_count * column_count);
281-
ONEDAL_ASSERT(data.has_mutable_data());
282-
ONEDAL_ASSERT(is_same_context(queue, data));
283-
284-
const SharedPtr<Data> data_shared{ data.get_mutable_data(), daal_object_owner{ data } };
285-
286-
Status status;
287-
const Buffer<Data> buffer{ data_shared,
288-
integral_cast<std::size_t>(data.get_count()),
289-
queue,
290-
status };
291-
status_to_exception(status);
292-
293-
const auto table =
294-
SyclHomogenNumericTable<Data>::create(buffer,
295-
integral_cast<std::size_t>(column_count),
296-
integral_cast<std::size_t>(row_count),
297-
&status);
298-
status_to_exception(status);
299-
300-
return table;
301-
}
302-
#endif
303-
304-
} // namespace oneapi::dal::backend::interop
17+
#include "oneapi/dal/table/backend/interop/table_conversion.hpp"

cpp/oneapi/dal/detail/error_messages.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ MSG(object_does_not_provide_read_access_to_csr,
116116
"Given object does not provide read access to the block of CSR format")
117117
MSG(pull_column_interface_is_not_implemented,
118118
"Pull column interface is planned but not implemented yet")
119+
MSG(unsupported_table_conversion, "Unsupported table conversion")
119120

120121
/* Ranges */
121122
MSG(invalid_range_of_rows, "Invalid range of rows")

0 commit comments

Comments
 (0)