Skip to content

Commit da1aadc

Browse files
committed
Use legate buffer instead of vector
1 parent 90b6f23 commit da1aadc

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

src/models/tree/build_tree.cc

+22-7
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,42 @@ namespace legateboost {
3030
namespace {
3131

3232
class BinnedX {
33-
// We could use int16 here if we store the indices offset from their minimum value
34-
std::vector<int32_t> data;
33+
// These are stored as int16_t to save space
34+
// Indices are relative to the feature, not the entire histogram
35+
// The maximum number of bins in legate-boost is 2048
36+
legate::Buffer<int16_t> data;
3537

3638
public:
3739
legate::Rect<3> shape;
40+
legate::Buffer<int32_t> row_pointers;
3841
int64_t num_features;
3942
int64_t num_rows;
4043
template <typename T>
4144
BinnedX(legate::AccessorRO<T, 3> X,
4245
legate::Rect<3> shape,
4346
const SparseSplitProposals<T>& split_proposals)
4447
: shape(shape),
45-
num_features(shape.hi[1] - shape.lo[1] + 1),
46-
num_rows(shape.hi[0] - shape.lo[0] + 1)
48+
row_pointers(legate::create_buffer<int32_t, 1>(split_proposals.row_pointers.size())),
49+
num_features(std::max(shape.hi[1] - shape.lo[1] + 1, 0ll)),
50+
num_rows(std::max(shape.hi[0] - shape.lo[0] + 1, 0ll))
4751
{
48-
data.resize(num_features * num_rows);
52+
data = legate::create_buffer<int16_t>(num_features * num_rows);
53+
std::copy(split_proposals.row_pointers.begin(),
54+
split_proposals.row_pointers.end(),
55+
row_pointers.ptr(0));
4956
for (int i = 0; i < num_rows; i++) {
5057
for (int j = 0; j < num_features; j++) {
51-
data[i * num_features + j] = split_proposals.FindBin(X[{i, j, 0}], j);
58+
auto bin_idx = split_proposals.FindBin(X[{i, j, 0}], j);
59+
// Store the bin index relative to the feature to save space
60+
data[i * num_features + j] = bin_idx - row_pointers[j];
5261
}
5362
}
5463
}
55-
int64_t operator[](const legate::Point<2>& p) const { return data[p[0] * num_features + p[1]]; }
64+
// This should use the local row index, not global
65+
int64_t operator[](const legate::Point<2>& p) const
66+
{
67+
return data[p[0] * num_features + p[1]] + row_pointers[p[1]];
68+
}
5669
};
5770

5871
struct NodeBatch {
@@ -194,6 +207,8 @@ auto SelectSplitSamples(legate::TaskContext context,
194207
split_proposals_tmp.insert(split_proposals_tmp.end(), unique.begin(), unique.end());
195208
}
196209

210+
draft_proposals.destroy();
211+
197212
auto split_proposals = legate::create_buffer<T, 1>(split_proposals_tmp.size());
198213
std::copy(split_proposals_tmp.begin(), split_proposals_tmp.end(), split_proposals.ptr(0));
199214

0 commit comments

Comments
 (0)