diff --git a/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh new file mode 100644 index 0000000000..727b476d63 --- /dev/null +++ b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::sparse::matrix::detail { + +/** + * This function creates a representation of input data (rows) that identifies when + * value changes in the input data. This function assumes data is sorted. + * + * @param[in] rows + * The input data + * @param[in] nnz + * The size of the input data. + * @param[in] counts + * The resulting representation of the index value changes of the input. Should be + * the same size as the input (nnz) + */ +__global__ void _scan(int* rows, int nnz, int* counts) +{ + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= nnz) { return; } + if (index == 0) { + counts[index] = 1; + return; + } + if (index < nnz) { + int curr_id = rows[index]; + int old_id = rows[index - 1]; + if (curr_id != old_id) { + counts[index] = 1; + } else { + counts[index] = 0; + } + } +} + +/** + * This function counts the occurrences of the input array. Uses modulo logic as a + * rudimentary hash (should be changed with better hash function). + * + * @param[in] cols + * The input data + * @param[in] nnz + * The size of the input data. + * @param[in] counts + * The resulting representation of the index value changes of the input. Should be + * the same size as the input (nnz) + * @param[in] feats + * The array that will house the occurrence counts + * @param[in] vocabSize + * The size of the occurrence counts array (feats). + */ +__global__ void _fit_compute_occurs(int* cols, int nnz, int* counts, int* feats, int vocabSize) +{ + int index = blockIdx.x * blockDim.x + threadIdx.x; + if ((index < nnz) && (counts[index] == 1)) { + int targetVal = cols[index]; + int vocab = targetVal % vocabSize; + while (targetVal == cols[index]) { + feats[vocab] = feats[vocab] + 1; + index++; + if (index >= nnz) { return; } + } + } +} + +/** + * This function calculates tfidf or bm25, depending on options supplied, from the + * values input array. + * + * @param[in] rows + * The input rows. + * @param[in] columns + * The input columns (features). + * @param[in] values + * The input values. + * @param[in] feat_id_count + * The array holding the feature(column) occurrence counts for all fitted inputs. + * @param[in] counts + * The array representing value changes in rows input. + * @param[in] out_values + * The array that will store calculated values, should be size NNZ. + * @param[in] vocabSize + * The number of the features (columns). + * @param[in] num_rows + * Total number of rows for all fitted inputs. + * @param[in] avgRowLen + * The average length of a row (sum of all values for each row). + * @param[in] k + * The bm25 formula variable. Helps with optimization. + * @param[in] b + * The bm25 formula variable. Helps with optimization. + * @param[in] nnz + * The size of the input arrays (rows, columns, values). + * @param[in] bm25 + * Boolean that activates bm25 calculation instead of tfidf + */ +__global__ void _transform(int* rows, + int* columns, + float* values, + int* feat_id_count, + int* counts, + float* out_values, + int num_rows, + float avgRowLen, + float k, + float b, + int nnz, + int vocabSize, + bool bm25 = false) +{ + int start_index = blockIdx.x * blockDim.x + threadIdx.x; + int index = start_index; + if (index < nnz && counts[index] == 1) { + int row_length = 0; + int targetVal = rows[index]; + while (targetVal == rows[index]) { + row_length += values[index]; + index++; + if (index >= nnz) { break; } + } + index = start_index; + float result; + while (targetVal == rows[index]) { + int col = columns[index]; + int vocab = col % vocabSize; + float tf = (float)values[index] / row_length; + double idf_in = (double)num_rows / feat_id_count[vocab]; + float idf = (float)raft::log(idf_in); + result = tf * idf; + if (bm25) { + float bm = ((k + 1) * tf) / (k * ((1.0f - b) + b * (row_length / avgRowLen)) + tf); + result = idf * bm; + } + out_values[index] = result; + index++; + if (index >= nnz) { break; } + } + } +} + +/** + * This function converts a raft csr matrix in to a coo (rows, columns,values) + * representation. + * + * @param[in] handle + * The input data + * @param[in] csr_in + * The input raft csr matrix. + * @param[in] rows + * The output rows from the csr conversion. + * @param[in] columns + * The output columns from the csr conversion. + * @param[in] values + * The output values from the csr conversion. + */ +template +void convert_csr_to_coo(raft::resources& handle, + raft::device_csr_matrix csr_in, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto nnz = csr_in.structure_view().get_nnz(); + auto indptr = csr_in.structure_view().get_indptr(); + auto indices = csr_in.structure_view().get_indices(); + auto vals = csr_in.view().get_elements(); + + raft::sparse::convert::csr_to_coo( + indptr.data(), (int)indptr.size(), rows.data_handle(), (int)nnz, stream); + raft::copy(columns.data_handle(), indices.data(), (int)nnz, stream); + raft::copy(values.data_handle(), vals.data(), (int)nnz, stream); +} + +} // namespace raft::sparse::matrix::detail diff --git a/cpp/include/raft/sparse/matrix/preprocessing.cuh b/cpp/include/raft/sparse/matrix/preprocessing.cuh new file mode 100644 index 0000000000..891dbc12eb --- /dev/null +++ b/cpp/include/raft/sparse/matrix/preprocessing.cuh @@ -0,0 +1,446 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace raft::sparse::matrix { + +/** + * The class facilitates the creation a tfidf and bm25 encoding values for sparse matrices + * + * This class creates tfidf and bm25 encoding values for sparse matrices. It allows for + * batched matrix processing by calling fit for all matrix chunks. Once all matrices have + * been fitted, the user can use transform to actually produce the encoded values for each + * subset (chunk) matrix. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] featIdCount + * An array that holds the count of how many different rows each feature occurs in. + * @param[in] fullIdLen + * The value that represents the total number of words seen during the fit process. + * @param[out] vocabSize + * A value that represents the number of features that exist for the matrices encoded. + * @param[out] numRows + * The number of rows observed during the fit process, accumulates over all fit calls. + */ +template +class SparseEncoder { + private: + int* featIdCount; + float fullIdLen; + int vocabSize; + int numRows; + + public: + SparseEncoder(int vocab_size); + SparseEncoder(std::map featIdValues, int num_rows, int full_id_len, int vocab_size); + ~SparseEncoder(); + void fit(raft::resources& handle, + raft::device_coo_matrix coo_in); + void fit(raft::resources& handle, + raft::device_csr_matrix csr_in); + void transform(raft::resources& handle, + raft::device_csr_matrix csr_in, + float* results, + bool bm25_on, + float k_param = 1.6f, + float b_param = 0.75f); + void transform(raft::resources& handle, + raft::device_coo_matrix coo_in, + float* results, + bool bm25_on, + float k_param = 1.6f, + float b_param = 0.75f); + + private: + void _fit(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + int num_rows); + void _fit_feats(IndexType* cols, IndexType* counts, IndexType nnz, IndexType* results); + void transform(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + IndexType nnz, + ValueType* results, + bool bm25_on, + float k_param = 1.6f, + float b_param = 0.75f); +}; + +/** + * This constructor creates the `SparseEncoder` class with a vocabSize equal to the + * int vocab parameter supplied. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] vocab + * Value that represents the number of features that exist for the matrices encoded. + */ +template +SparseEncoder::SparseEncoder(int vocab) : vocabSize(vocab) +{ + cudaMallocManaged(&featIdCount, vocab * sizeof(int)); + fullIdLen = 0.0f; + numRows = 0; + for (int i = 0; i < vocabSize; i++) { + featIdCount[i] = 0; + } +} +/** + * This constructor creates the `SparseEncoder` class with a vocabSize equal to the + * int vocab parameter supplied. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] featIdValues + * A map that consists of all the indices and values, to populate the featIdCount array. + * @param[in] num_rows + * Value that represents the number of rows observed during fit cycle. + * @param[in] full_id_len + * Value that represents the number overall number of features observed during the fit + * cycle. + * @param[in] vocab_size + * Value that represents the number of features that exist for the matrices encoded. + * */ +template +SparseEncoder::SparseEncoder(std::map featIdValues, + int num_rows, + int full_id_len, + int vocab_size) + : vocabSize(vocab_size), numRows(num_rows), fullIdLen(full_id_len) +{ + cudaMallocManaged(&featIdCount, vocabSize * sizeof(int)); + cudaMemset(featIdCount, 0, vocabSize * sizeof(int)); + + for (const auto& item : featIdValues) { + featIdCount[item.first] = item.second; + } +} + +/** + * This destructor deallocates/frees the reserved memory of the class. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * */ +template +SparseEncoder::~SparseEncoder() +{ + cudaFree(featIdCount); +} + +template +void SparseEncoder::_fit_feats(IndexType* cols, + IndexType* counts, + IndexType nnz, + IndexType* results) +{ + int blockSize = (nnz < 256) ? nnz : 256; + int num_blocks = (nnz + blockSize - 1) / blockSize; + raft::sparse::matrix::detail::_scan<<>>(cols, nnz, counts); + raft::sparse::matrix::detail::_fit_compute_occurs<<>>( + cols, nnz, counts, results, vocabSize); +} + +template +void SparseEncoder::_fit(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + int num_rows) +{ + numRows += num_rows; + IndexType nnz = values.size(); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto batchIdLen = raft::make_host_scalar(handle, 0); + auto values_mat = raft::make_device_scalar(handle, 0); + raft::linalg::mapReduce(values_mat.data_handle(), + nnz, + 0.0f, + raft::identity_op(), + raft::add_op(), + stream, + values.data_handle()); + raft::copy(batchIdLen.data_handle(), values_mat.data_handle(), values_mat.size(), stream); + fullIdLen += (ValueType)batchIdLen(0); + auto d_rows = raft::make_device_vector(handle, nnz); + auto d_cols = raft::make_device_vector(handle, nnz); + auto d_vals = raft::make_device_vector(handle, nnz); + raft::copy(d_rows.data_handle(), rows.data_handle(), nnz, stream); + raft::copy(d_cols.data_handle(), columns.data_handle(), nnz, stream); + raft::copy(d_vals.data_handle(), values.data_handle(), nnz, stream); + raft::sparse::op::coo_sort( + nnz, nnz, nnz, d_cols.data_handle(), d_rows.data_handle(), d_vals.data_handle(), stream); + IndexType* counts; + cudaMallocManaged(&counts, nnz * sizeof(IndexType)); + cudaMemset(counts, 0, nnz * sizeof(IndexType)); + _fit_feats(d_cols.data_handle(), counts, nnz, featIdCount); + cudaFree(counts); + cudaDeviceSynchronize(); +} + +/** + * This function fits the input matrix, recording required statistics to later create + * encoding values. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] coo_in + * Raft container housing a coordinate format sparse matrix representation. + + * */ +template +void SparseEncoder::fit(raft::resources& handle, + raft::device_coo_matrix coo_in) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_rows = coo_in.structure_view().get_n_rows(); + auto rows = coo_in.structure_view().get_rows(); + auto cols = coo_in.structure_view().get_cols(); + auto vals = coo_in.view().get_elements(); + auto nnz = coo_in.structure_view().get_nnz(); + + auto d_rows = raft::make_device_vector(handle, nnz); + auto d_cols = raft::make_device_vector(handle, nnz); + auto d_vals = raft::make_device_vector(handle, nnz); + + raft::copy(d_rows.data_handle(), rows.data(), nnz, stream); + raft::copy(d_cols.data_handle(), cols.data(), nnz, stream); + raft::copy(d_vals.data_handle(), vals.data(), nnz, stream); + + _fit(handle, d_rows.view(), d_cols.view(), d_vals.view(), n_rows); +} + +/** + * This function fits the input matrix, recording required statistics to later create + * encoding values. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] csr_in + * Raft container housing a compressed sparse row matrix representation. + + * */ +template +void SparseEncoder::fit(raft::resources& handle, + raft::device_csr_matrix csr_in) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto nnz = csr_in.structure_view().get_nnz(); + auto rows = raft::make_device_vector(handle, nnz); + auto columns = raft::make_device_vector(handle, nnz); + auto values = raft::make_device_vector(handle, nnz); + + raft::sparse::matrix::detail::convert_csr_to_coo( + handle, csr_in, rows.view(), columns.view(), values.view()); + _fit(handle, rows.view(), columns.view(), values.view(), csr_in.structure_view().get_n_rows()); +} +/** + * This function transforms the coo matrix based on statistics collected during fit + * cycle. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] coo_in + * Raft container housing a compressed sparse row matrix representation. + + * */ +template +void SparseEncoder::transform( + raft::resources& handle, + raft::device_coo_matrix coo_in, + float* results, + bool bm25_on, + float k_param, + float b_param) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto rows = coo_in.structure_view().get_rows(); + auto cols = coo_in.structure_view().get_cols(); + auto vals = coo_in.view().get_elements(); + auto nnz = coo_in.structure_view().get_nnz(); + + auto d_rows = raft::make_device_vector(handle, nnz); + auto d_cols = raft::make_device_vector(handle, nnz); + auto d_vals = raft::make_device_vector(handle, nnz); + raft::copy(d_rows.data_handle(), rows.data(), nnz, stream); + raft::copy(d_cols.data_handle(), cols.data(), nnz, stream); + raft::copy(d_vals.data_handle(), vals.data(), nnz, stream); + + transform( + handle, d_rows.view(), d_cols.view(), d_vals.view(), nnz, results, bm25_on, k_param, b_param); +} + +/** + * This function transforms the csr matrix based on statistics collected during fit + * cycle. + * + * @tparam ValueType + * Type of the values in the sparse matrix. + * @tparam IndexType + * Type of the indices associated with the values. + * + * @param[in] handle + * Container for managing reusable resources. + * @param[in] csr_in + * Raft container housing a compressed sparse row matrix representation. + + * */ +template +void SparseEncoder::transform( + raft::resources& handle, + raft::device_csr_matrix csr_in, + float* results, + bool bm25_on, + float k_param, + float b_param) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto nnz = csr_in.structure_view().get_nnz(); + auto rows = raft::make_device_vector(handle, nnz); + auto columns = raft::make_device_vector(handle, nnz); + auto values = raft::make_device_vector(handle, nnz); + raft::sparse::matrix::detail::convert_csr_to_coo( + handle, csr_in, rows.view(), columns.view(), values.view()); + transform( + handle, rows.view(), columns.view(), values.view(), nnz, results, bm25_on, k_param, b_param); +} + +template +void SparseEncoder::transform( + raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + IndexType nnz, + ValueType* results, + bool bm25_on, + float k_param, + float b_param) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + int maxLimit = nnz; + int blockSize = (maxLimit < 128) ? maxLimit : 128; + int num_blocks = (maxLimit + blockSize - 1) / blockSize; + float avgIdLen = (ValueType)fullIdLen / numRows; + int* counts; + cudaMallocManaged(&counts, maxLimit * sizeof(IndexType)); + cudaMemset(counts, 0, maxLimit * sizeof(IndexType)); + raft::sparse::op::coo_sort(IndexType(rows.size()), + IndexType(columns.size()), + IndexType(values.size()), + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + stream); + raft::sparse::matrix::detail::_scan<<>>(rows.data_handle(), nnz, counts); + raft::sparse::matrix::detail::_transform<<>>(rows.data_handle(), + columns.data_handle(), + values.data_handle(), + featIdCount, + counts, + results, + numRows, + avgIdLen, + k_param, + b_param, + nnz, + vocabSize, + bm25_on); + cudaFree(counts); + cudaDeviceSynchronize(); +} +} // namespace raft::sparse::matrix diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9f96b93e7a..3c290f6089 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -256,6 +256,7 @@ if(BUILD_TESTS) sparse/masked_matmul.cu sparse/norm.cu sparse/normalize.cu + sparse/preprocess.cu sparse/reduce.cu sparse/row_op.cu sparse/sddmm.cu diff --git a/cpp/tests/sparse/preprocess.cu b/cpp/tests/sparse/preprocess.cu new file mode 100644 index 0000000000..84c16c34f5 --- /dev/null +++ b/cpp/tests/sparse/preprocess.cu @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "../util/preprocess_utils.cu" + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft { +namespace sparse { + +template +void get_clean_coo(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + int nnz, + int num_rows, + int num_cols, + raft::sparse::COO& coo) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + raft::sparse::op::coo_sort(int(rows.size()), + int(columns.size()), + int(values.size()), + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + stream); + + // raft::sparse::COO coo(stream); + raft::sparse::op::max_duplicates(handle, + coo, + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + nnz, + num_rows, + num_cols); +} +template +raft::device_coo_matrix +create_coo_matrix(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + int num_rows, + int num_cols) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto coo_struct_view = raft::make_device_coordinate_structure_view( + rows.data_handle(), columns.data_handle(), num_rows, num_cols, int(rows.size())); + auto c_matrix = raft::make_device_coo_matrix(handle, coo_struct_view); + raft::update_device( + c_matrix.view().get_elements().data(), values.data_handle(), int(values.size()), stream); + return c_matrix; +} + +template +raft::device_coo_matrix +create_coo_matrix(raft::resources& handle, raft::sparse::COO& coo) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto coo_struct_view = raft::make_device_coordinate_structure_view( + coo.rows(), coo.cols(), coo.n_rows, coo.n_cols, int(coo.nnz)); + auto c_matrix = raft::make_device_coo_matrix(handle, coo_struct_view); + raft::update_device(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream); + return c_matrix; +} + +template +struct SparsePreprocessInputs { + int n_rows; + int n_cols; + int nnz_edges; +}; + +template +class SparsePreprocessCSR + : public ::testing::TestWithParam> { + public: + SparsePreprocessCSR() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)) + { + } + + protected: + void SetUp() override {} + + void Run(bool bm25_on, bool coo_on) + { + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + int num_rows = pow(2, params.n_rows); + int num_cols = pow(2, params.n_cols); + int nnz = params.nnz_edges; + auto a_rows = raft::make_device_vector(handle, nnz); + auto a_columns = raft::make_device_vector(handle, nnz); + auto a_values = raft::make_device_vector(handle, nnz); + + raft::util::create_dataset( + handle, a_rows.view(), a_columns.view(), a_values.view(), 5, params.n_rows, params.n_cols); + + auto b_rows = raft::make_device_vector(handle, nnz); + auto b_columns = raft::make_device_vector(handle, nnz); + auto b_values = raft::make_device_vector(handle, nnz); + raft::util::create_dataset(handle, + b_rows.view(), + b_columns.view(), + b_values.view(), + 5, + params.n_rows, + params.n_cols, + 67584); + + raft::sparse::COO coo_a(stream); + get_clean_coo( + handle, a_rows.view(), a_columns.view(), a_values.view(), nnz, num_rows, num_cols, coo_a); + + raft::sparse::COO coo_b(stream); + get_clean_coo( + handle, b_rows.view(), b_columns.view(), b_values.view(), nnz, num_rows, num_cols, coo_b); + + auto b_rows_h = raft::make_host_vector(handle, coo_b.nnz); + raft::copy(b_rows_h.data_handle(), coo_b.rows(), coo_b.nnz, stream); + for (int i = 0; i < coo_b.nnz; i++) { + int new_val = b_rows_h(i) + coo_a.n_rows; + b_rows_h(i) = new_val; + } + auto b_rows_stack = raft::make_device_vector(handle, coo_b.nnz); + raft::copy(b_rows_stack.data_handle(), b_rows_h.data_handle(), coo_b.nnz, stream); + + int ab_nnz = coo_b.nnz + coo_a.nnz; + auto ab_rows = raft::make_device_vector(handle, ab_nnz); + auto ab_columns = raft::make_device_vector(handle, ab_nnz); + auto ab_values = raft::make_device_vector(handle, ab_nnz); + + raft::copy(ab_rows.data_handle(), coo_a.rows(), coo_a.nnz, stream); + raft::copy(ab_rows.data_handle() + coo_a.nnz, b_rows_stack.data_handle(), coo_b.nnz, stream); + raft::copy(ab_columns.data_handle(), coo_a.cols(), coo_a.nnz, stream); + raft::copy(ab_columns.data_handle() + coo_a.nnz, coo_b.cols(), coo_b.nnz, stream); + raft::copy(ab_values.data_handle(), coo_a.vals(), coo_a.nnz, stream); + raft::copy(ab_values.data_handle() + coo_a.nnz, coo_b.vals(), coo_b.nnz, stream); + + int merged_num_rows = coo_a.n_rows + coo_b.n_rows; + + auto rows_csr = raft::make_device_vector(handle, merged_num_rows + 1); + raft::sparse::convert::sorted_coo_to_csr( + ab_rows.data_handle(), ab_rows.size(), rows_csr.data_handle(), merged_num_rows + 1, stream); + auto csr_struct_view = raft::make_device_compressed_structure_view(rows_csr.data_handle(), + ab_columns.data_handle(), + merged_num_rows, + num_cols, + int(ab_values.size())); + auto csr_matrix = + raft::make_device_csr_matrix(handle, csr_struct_view); + raft::update_device(csr_matrix.view().get_elements().data(), + ab_values.data_handle(), + int(ab_values.size()), + stream); + + auto result = raft::make_device_vector(handle, int(ab_values.size())); + raft::sparse::matrix::SparseEncoder* sparseEncoder = + new raft::sparse::matrix::SparseEncoder(num_cols); + + if (coo_on) { + auto a_matrix = create_coo_matrix(handle, coo_a); + auto b_matrix = create_coo_matrix(handle, coo_b); + auto c_matrix = create_coo_matrix( + handle, ab_rows.view(), ab_columns.view(), ab_values.view(), num_rows * 2, num_cols); + sparseEncoder->fit(handle, a_matrix); + sparseEncoder->fit(handle, b_matrix); + sparseEncoder->save(handle, "test_save.txt"); + sparseEncoder = + raft::sparse::matrix::loadSparseEncoder(handle, "test_save.txt"); + sparseEncoder->transform(handle, c_matrix, result.data_handle(), bm25_on); + } else { + sparseEncoder->fit(handle, csr_matrix); + sparseEncoder->transform(handle, csr_matrix, result.data_handle(), bm25_on); + } + delete sparseEncoder; + + if (bm25_on) { + auto bm25_vals = raft::make_device_vector(handle, int(ab_values.size())); + raft::util::calc_tfidf_bm25(handle, csr_matrix.view(), bm25_vals.view()); + ASSERT_TRUE(raft::devArrMatch(bm25_vals.data_handle(), + result.data_handle(), + result.size(), + raft::CompareApprox(2e-5), + stream)); + } else { + auto tfidf_vals = raft::make_device_vector(handle, int(ab_values.size())); + raft::util::calc_tfidf_bm25( + handle, csr_matrix.view(), tfidf_vals.view(), true); + ASSERT_TRUE(raft::devArrMatch(tfidf_vals.data_handle(), + result.data_handle(), + result.size(), + raft::CompareApprox(2e-5), + stream)); + } + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + SparsePreprocessInputs params; +}; + +using SparsePreprocessFF = SparsePreprocessCSR; +TEST_P(SparsePreprocessFF, Result) { Run(false, false); } + +using SparsePreprocessTT = SparsePreprocessCSR; +TEST_P(SparsePreprocessTT, Result) { Run(true, true); } + +using SparsePreprocessFT = SparsePreprocessCSR; +TEST_P(SparsePreprocessFT, Result) { Run(false, true); } + +using SparsePreprocessTF = SparsePreprocessCSR; +TEST_P(SparsePreprocessTF, Result) { Run(true, false); } + +using SparsePreprocessBigFF = SparsePreprocessCSR; +TEST_P(SparsePreprocessBigFF, Result) { Run(false, false); } + +using SparsePreprocessBigTT = SparsePreprocessCSR; +TEST_P(SparsePreprocessBigTT, Result) { Run(true, true); } + +using SparsePreprocessBigFT = SparsePreprocessCSR; +TEST_P(SparsePreprocessBigFT, Result) { Run(false, true); } + +using SparsePreprocessBigTF = SparsePreprocessCSR; +TEST_P(SparsePreprocessBigTF, Result) { Run(true, false); } + +const std::vector> sparse_preprocess_inputs = { + { + 7, // n_rows_factor + 5, // n_cols_factor + 100 // num nnz values + }, +}; + +const std::vector> sparse_preprocess_inputs_big = { + { + 14, // n_rows_factor + 14, // n_cols_factor + 500000 // nnz_edges + }, +}; + +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessFF, + ::testing::ValuesIn(sparse_preprocess_inputs)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessTT, + ::testing::ValuesIn(sparse_preprocess_inputs)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessFT, + ::testing::ValuesIn(sparse_preprocess_inputs)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessTF, + ::testing::ValuesIn(sparse_preprocess_inputs)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessBigTT, + ::testing::ValuesIn(sparse_preprocess_inputs_big)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessBigFF, + ::testing::ValuesIn(sparse_preprocess_inputs_big)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessBigTF, + ::testing::ValuesIn(sparse_preprocess_inputs_big)); +INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR, + SparsePreprocessBigFT, + ::testing::ValuesIn(sparse_preprocess_inputs_big)); + +} // namespace sparse +} // namespace raft diff --git a/cpp/tests/util/preprocess_utils.cu b/cpp/tests/util/preprocess_utils.cu new file mode 100644 index 0000000000..d3870d67f8 --- /dev/null +++ b/cpp/tests/util/preprocess_utils.cu @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::util { + +template +void print_vals(raft::resources& handle, const raft::device_vector_view& out) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto h_out = raft::make_host_vector(out.size()); + raft::copy(h_out.data_handle(), out.data_handle(), out.size(), stream); + int limit = int(out.size()); + for (int i = 0; i < limit; i++) { + std::cout << float(h_out(i)) << ", "; + } + std::cout << std::endl; +} + +template +void print_vals(raft::resources& handle, T* out, int len) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto h_out = raft::make_host_vector(handle, len); + raft::copy(h_out.data_handle(), out, len, stream); + int limit = int(len); + for (int i = 0; i < limit; i++) { + std::cout << float(h_out(i)) << ", "; + } + std::cout << std::endl; +} + +template +struct check_zeroes { + float __device__ operator()(const T1& value, const T2& idx) + { + if (value == 0) { + return 0.f; + } else { + return 1.f; + } + } +}; + +template +void preproc(raft::resources& handle, + raft::device_vector_view dense_values, + raft::device_vector_view results, + int num_rows, + int num_cols, + bool tf_idf) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + // create matrix and copy to device + auto host_dense_vals = raft::make_host_vector(handle, dense_values.size()); + raft::copy( + host_dense_vals.data_handle(), dense_values.data_handle(), dense_values.size(), stream); + + auto host_matrix = + raft::make_host_matrix_view(host_dense_vals.data_handle(), num_rows, num_cols); + auto device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); + + raft::copy(device_matrix.data_handle(), host_matrix.data_handle(), host_matrix.size(), stream); + + // get sum reduce for each row (length of the document) + auto output_rows_lengths = raft::make_device_matrix(handle, 1, num_rows); + raft::linalg::reduce(output_rows_lengths.data_handle(), + device_matrix.data_handle(), + num_cols, + num_rows, + 0.0f, + true, + true, + stream); + auto h_output_rows_lengths = raft::make_host_matrix(handle, 1, num_rows); + raft::copy(h_output_rows_lengths.data_handle(), + output_rows_lengths.data_handle(), + output_rows_lengths.size(), + stream); + + // find the avg size of a document + auto output_rows_length_sum = raft::make_device_scalar(handle, 0); + raft::linalg::mapReduce(output_rows_length_sum.data_handle(), + num_rows, + 0.0f, + raft::identity_op(), + raft::add_op(), + stream, + output_rows_lengths.data_handle()); + auto h_output_rows_length_sum = raft::make_host_scalar(handle, 0); + raft::copy(h_output_rows_length_sum.data_handle(), + output_rows_length_sum.data_handle(), + output_rows_length_sum.size(), + stream); + T2 avg_row_length = (T2)h_output_rows_length_sum(0) / num_rows; + + // find the number of docs(row) each vocab(col) word is in + auto output_cols_cnt = raft::make_device_matrix(handle, 1, num_cols); + raft::linalg::reduce(output_cols_cnt.data_handle(), + device_matrix.data_handle(), + num_cols, + num_rows, + 0.0f, + true, + false, + stream, + false, + check_zeroes()); + auto h_output_cols_cnt = raft::make_host_matrix(handle, 1, num_cols); + raft::copy( + h_output_cols_cnt.data_handle(), output_cols_cnt.data_handle(), output_cols_cnt.size(), stream); + + // perform bm25/tfidf calculations + auto out_device_matrix = raft::make_device_matrix(handle, num_rows, num_cols); + raft::matrix::fill(handle, out_device_matrix.view(), 0.0f); + auto out_host_matrix = raft::make_host_matrix(handle, num_rows, num_cols); + auto out_host_vector = raft::make_host_vector(handle, results.size()); + + float k1 = 1.6f; + float b = 0.75f; + int count = 0; + float result; + for (int row = 0; row < num_rows; row++) { + for (int col = 0; col < num_cols; col++) { + float val = host_matrix(row, col); + // std::cout << val << ", "; + if (val == 0) { + out_host_matrix(row, col) = 0.0f; + } else { + float tf = (float)val / h_output_rows_lengths(0, row); + double idf_in = (double)num_rows / h_output_cols_cnt(0, col); + float idf = (float)raft::log(idf_in); + if (tf_idf) { + result = tf * idf; + } else { + float bm25 = ((k1 + 1) * tf) / + (k1 * ((1 - b) + b * (h_output_rows_lengths(0, row) / avg_row_length)) + tf); + result = idf * bm25; + } + out_host_matrix(row, col) = result; + out_host_vector(count) = result; + count++; + } + } + } + + raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream); +} + +template +void calc_tfidf_bm25(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view results, + bool tf_idf = false) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + int num_rows = csr_in.structure_view().get_n_rows(); + int num_cols = csr_in.structure_view().get_n_cols(); + int rows_size = csr_in.structure_view().get_indptr().size(); + int cols_size = csr_in.structure_view().get_indices().size(); + int elements_size = csr_in.get_elements().size(); + + auto indptr = raft::make_device_vector_view( + csr_in.structure_view().get_indptr().data(), rows_size); + auto indices = raft::make_device_vector_view( + csr_in.structure_view().get_indices().data(), cols_size); + auto values = + raft::make_device_vector_view(csr_in.get_elements().data(), elements_size); + auto dense_values = raft::make_device_vector(handle, num_rows * num_cols); + + cusparseHandle_t cu_handle; + RAFT_CUSPARSE_TRY(cusparseCreate(&cu_handle)); + + raft::sparse::convert::csr_to_dense(cu_handle, + num_rows, + num_cols, + elements_size, + indptr.data_handle(), + indices.data_handle(), + values.data_handle(), + num_rows, + dense_values.data_handle(), + stream, + true); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + preproc(handle, dense_values.view(), results, num_rows, num_cols, tf_idf); +} + +template +void create_dataset(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + int max_term_occurence_doc = 5, + int num_rows_unique = 7, + int num_cols_unique = 7, + int seed = 12345) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + raft::random::RngState rng(seed); + + auto d_out = raft::make_device_vector(handle, rows.size() * 2); + + int theta_guide = max(num_rows_unique, num_cols_unique); + auto theta = raft::make_device_vector(handle, theta_guide * 4); + + raft::random::uniform(handle, rng, theta.view(), 0.0f, 1.0f); + + raft::random::rmat_rectangular_gen(d_out.data_handle(), + rows.data_handle(), + columns.data_handle(), + theta.data_handle(), + num_rows_unique, + num_cols_unique, + int(values.size()), + stream, + rng); + + auto vals = raft::make_device_vector(handle, rows.size()); + raft::random::uniformInt(handle, rng, vals.view(), 1, max_term_occurence_doc); + raft::linalg::map(handle, values, raft::cast_op{}, raft::make_const_mdspan(vals.view())); +} + +}; // namespace raft::util