Skip to content

Commit 5a9c8a6

Browse files
authored
Implement dsjson parsing, rust side pooling (#8)
* Implement dsjson parsing * Implement parse, setup, learn, predict * add comment * add comment * improve error message * fix leak * use size_t * pooling, example usage for parallel training * add to test * Formatting * rearrange func fixes segfault which is concerning * dont generate layout tests
1 parent 001fc8f commit 5a9c8a6

23 files changed

+1503
-113
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ Cargo.lock
44
build
55
.cache
66
target/
7+
.DS_Store
8+
example_datafile.json
9+
.idea/

.vscode/settings.json

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"clangd.arguments": ["-compile-commands-dir=bindings/build"]
3+
}

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
members = [
44
"vowpalwabbit",
55
"vowpalwabbit-sys",
6+
"par-dsjson",
7+
"par-dsjson-unbounded",
68
]

binding/.clang-format

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
BasedOnStyle: Google
2+
AccessModifierOffset: -2
3+
AlignAfterOpenBracket: DontAlign
4+
AlignOperands: false
5+
AllowShortBlocksOnASingleLine: true
6+
AllowShortCaseLabelsOnASingleLine: false
7+
AllowShortFunctionsOnASingleLine: All
8+
AllowShortIfStatementsOnASingleLine: true
9+
AllowShortLoopsOnASingleLine: true
10+
BreakBeforeBraces: Allman
11+
BreakConstructorInitializersBeforeComma: true
12+
ColumnLimit: 120
13+
SortIncludes: true
14+
IndentPPDirectives: AfterHash
15+
PointerAlignment: Left
16+
DerivePointerAlignment: false

binding/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66
set(CMAKE_CXX_EXTENSIONS OFF)
77
set(CMAKE_VISIBILITY_INLINES_HIDDEN TRUE)
88
set(CMAKE_CXX_VISIBILITY_PRESET "hidden")
9+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
910

1011
cmake_policy(SET CMP0091 NEW)
1112
set(VCPKG_OVERLAY_PORTS "${CMAKE_CURRENT_LIST_DIR}/overlay-ports")

binding/include/vw_rs_bindings/bindings.hpp

+81-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
#pragma once
22

3+
#include <cstddef>
4+
#include <cstdint>
5+
36
#if defined _WIN32 || defined __CYGWIN__
47
# ifdef VW_RS_BUILDING_DLL
58
# ifdef __GNUC__
@@ -25,34 +28,97 @@
2528
# endif
2629
#endif
2730

28-
// For operations which cannot fail under any circumstance (except out of memory) it is acceptable to omit the return code, and error holder.
29-
// If it is an operation which can fail, it must return an error code and accept the error message parameter for filling with failure info.
31+
// For operations which cannot fail under any circumstance (except out of memory) it is acceptable to omit the return
32+
// code, and error holder. If it is an operation which can fail, it must return an error code and accept the error
33+
// message parameter for filling with failure info.
3034

3135
extern "C"
3236
{
3337
static const int VW_STATUS_SUCCESS = 0;
3438
static const int VW_STATUS_FAIL = 1;
3539

40+
// Unfortunately a copy paste of the enum since bringing in the header is not
41+
// feasible and using externs would mean these are no longer constants
42+
enum class override_prediction_type_t : uint32_t
43+
{
44+
scalar,
45+
scalars,
46+
action_scores,
47+
pdf,
48+
action_probs,
49+
multiclass,
50+
multilabels,
51+
prob,
52+
multiclassprobs, // not in use (technically oaa.cc)
53+
decision_probs,
54+
action_pdf_value,
55+
active_multiclass,
56+
nopred
57+
};
58+
3659
struct VWWorkspace;
3760
struct VWExample;
3861
struct VWErrorMessage;
62+
struct VWMultiEx;
3963

40-
DLL_PUBLIC struct VWErrorMessage* VWErrorMessageCreate() noexcept;
41-
DLL_PUBLIC void VWErrorMessageDelete(struct VWErrorMessage* error_message_handle) noexcept;
64+
struct VWActionScores;
65+
66+
DLL_PUBLIC VWErrorMessage* VWErrorMessageCreate() noexcept;
67+
DLL_PUBLIC void VWErrorMessageDelete(VWErrorMessage* error_message) noexcept;
4268
// If there was no error message set, a nullptr is returned.
43-
DLL_PUBLIC const char* VWErrorMessageGetValue(const struct VWErrorMessage* error_message_handle) noexcept;
44-
DLL_PUBLIC void VWErrorMessageClearValue(struct VWErrorMessage* error_message_handle) noexcept;
69+
DLL_PUBLIC const char* VWErrorMessageGetValue(const VWErrorMessage* error_message) noexcept;
70+
DLL_PUBLIC void VWErrorMessageClearValue(VWErrorMessage* error_message) noexcept;
4571

46-
DLL_PUBLIC int VWWorkspaceInitialize(const char* const* tokens, int count, struct VWWorkspace** output_handle,
47-
struct VWErrorMessage* error_message) noexcept;
48-
DLL_PUBLIC void VWWorkspaceDelete(struct VWWorkspace* workspace_handle) noexcept;
72+
DLL_PUBLIC int VWWorkspaceInitialize(
73+
const char* const* tokens, int count, VWWorkspace** output_handle, VWErrorMessage* error_message) noexcept;
74+
DLL_PUBLIC void VWWorkspaceDelete(VWWorkspace* workspace_handle) noexcept;
4975

50-
DLL_PUBLIC int VWWorkspaceLearn(struct VWWorkspace* workspace_handle, struct VWExample* example_handle, struct VWErrorMessage* error_message_handle) noexcept;
76+
DLL_PUBLIC int VWWorkspaceSetupExample(
77+
const VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept;
78+
DLL_PUBLIC int VWWorkspaceSetupMultiEx(
79+
const VWWorkspace* workspace_handle, VWMultiEx* example_handle, VWErrorMessage* error_message) noexcept;
5180

52-
DLL_PUBLIC int VWWorkspaceGetPooledExample(struct VWWorkspace* workspace_handle, struct VWExample** output_handle, struct VWErrorMessage* error_message_handle) noexcept;
53-
DLL_PUBLIC int VWWorkspaceReturnPooledExample(struct VWWorkspace* workspace_handle, struct VWExample* example_handle, struct VWErrorMessage* error_message_handle) noexcept;
81+
DLL_PUBLIC int VWWorkspaceLearn(
82+
VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept;
83+
DLL_PUBLIC int VWWorkspaceLearnMultiEx(
84+
VWWorkspace* workspace_handle, VWMultiEx* example_handle, VWErrorMessage* error_message) noexcept;
85+
// Will allocate a prediction based on the returned prediction_type. It must be deleted with the corresponding type
86+
// deleter.
87+
// TODO: tackle fact that predict sets test_only meaning that it is no longer able to be used in learn
88+
DLL_PUBLIC int VWWorkspacePredict(VWWorkspace* workspace_handle, VWExample* example_handle, void** prediction,
89+
uint32_t* prediction_type, VWErrorMessage* error_message) noexcept;
90+
// Will allocate a prediction based on the returned prediction_type. It must be deleted with the corresponding type
91+
// deleter.
92+
DLL_PUBLIC int VWWorkspacePredictMultiEx(VWWorkspace* workspace_handle, VWMultiEx* example_handle, void** prediction,
93+
uint32_t* prediction_type, VWErrorMessage* error_message) noexcept;
5494

55-
DLL_PUBLIC struct VWExample* VWExampleCreate() noexcept;
56-
DLL_PUBLIC void VWExampleDelete(struct VWExample* example_handle) noexcept;
57-
}
95+
typedef VWExample* VWExampleFactoryFunc(void*);
96+
DLL_PUBLIC int VWWorkspaceParseDSJson(const VWWorkspace* workspace_handle, const char* json_string, size_t length, VWExampleFactoryFunc example_factory, void* example_factory_context,
97+
VWMultiEx* output_handle, VWErrorMessage* error_message) noexcept;
5898

99+
DLL_PUBLIC VWExample* VWExampleCreate() noexcept;
100+
DLL_PUBLIC void VWExampleDelete(VWExample* example_handle) noexcept;
101+
DLL_PUBLIC void VWExampleClear(VWExample* example_handle) noexcept;
102+
103+
DLL_PUBLIC VWMultiEx* VWMultiExCreate() noexcept;
104+
// If any examples are held in the container they will be deleted too.
105+
DLL_PUBLIC void VWMultiExDelete(VWMultiEx* example_handle) noexcept;
106+
DLL_PUBLIC size_t VWMultiGetLength(const VWMultiEx* example_handle) noexcept;
107+
// Returns a pointer to that example.
108+
DLL_PUBLIC int VWMultiGetExampleAt(
109+
VWMultiEx* example_handle, VWExample** example, size_t index, VWErrorMessage* error_message) noexcept;
110+
// Releases the example at the index. Removes it from the collection and its lifetime must be managed by the caller.
111+
DLL_PUBLIC int VWMultiReleaseExampleAt(
112+
VWMultiEx* example_handle, VWExample** example, size_t index, VWErrorMessage* error_message) noexcept;
113+
// Deletes the example at that index
114+
DLL_PUBLIC int VWMultiDeleteExampleAt(
115+
VWMultiEx* example_handle, size_t index, VWErrorMessage* error_message) noexcept;
116+
// Lifetime transfers to the multiex. Use index == size to push at end.
117+
DLL_PUBLIC int VWMultiInsertExampleAt(
118+
VWMultiEx* example_handle, VWExample* example, size_t index, VWErrorMessage* error_message) noexcept;
119+
120+
DLL_PUBLIC void VWActionScoresDelete(VWActionScores* action_scores_handle) noexcept;
121+
DLL_PUBLIC void VWActionScoresGetLength(const VWActionScores* action_scores_handle, size_t* length) noexcept;
122+
DLL_PUBLIC int VWActionScoresGetValue(const VWActionScores* action_scores_handle, uint32_t* action, float* value,
123+
size_t index, VWErrorMessage* error_message) noexcept;
124+
}

0 commit comments

Comments
 (0)