Skip to content

Commit c888a0f

Browse files
authored
build tool for parallel parsing for training (#11)
* build tool for parallel parsing for training * formatting
1 parent 37fa079 commit c888a0f

File tree

15 files changed

+718
-320
lines changed

15 files changed

+718
-320
lines changed

Cargo.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33
members = [
44
"vowpalwabbit",
55
"vowpalwabbit-sys",
6-
"par-dsjson",
7-
"par-dsjson-unbounded",
6+
"tool",
87
]

binding/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,19 @@ set(VCPKG_OVERLAY_PORTS "${CMAKE_CURRENT_LIST_DIR}/overlay-ports")
1313

1414
project(vowpalwabbit-rs-bindings LANGUAGES CXX)
1515

16+
if(VW_RS_ASAN)
17+
add_compile_definitions(VW_USE_ASAN)
18+
if(MSVC)
19+
add_compile_options(/fsanitize=address /GS- /wd5072)
20+
add_link_options(/InferASanLibs /incremental:no /debug)
21+
# Workaround for MSVC ASan issue here: https://developercommunity.visualstudio.com/t/VS2022---Address-sanitizer-on-x86-Debug-/10116361
22+
add_compile_definitions(_DISABLE_STRING_ANNOTATION)
23+
else()
24+
add_compile_options(-fsanitize=address -fno-omit-frame-pointer -g3)
25+
add_link_options(-fsanitize=address -fno-omit-frame-pointer -g3)
26+
endif()
27+
endif()
28+
1629
set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")
1730

1831
find_package(VowpalWabbit CONFIG REQUIRED)

binding/include/vw_rs_bindings/bindings.hpp

+8
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ extern "C"
8282
// bytes is a c string and must be deleted using delete buffer
8383
DLL_PUBLIC int VWWorkspaceSerializeReadableModel(const VWWorkspace* workspace_handle, const unsigned char** bytes, size_t* num_bytes, VWErrorMessage* error_message) noexcept;
8484

85+
DLL_PUBLIC int VWWorkspaceEndPass(
86+
VWWorkspace* workspace_handle, VWErrorMessage* error_message) noexcept;
87+
8588
DLL_PUBLIC int VWWorkspaceSetupExample(
8689
const VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept;
8790
DLL_PUBLIC int VWWorkspaceSetupMultiEx(
@@ -101,6 +104,11 @@ extern "C"
101104
DLL_PUBLIC int VWWorkspacePredictMultiEx(VWWorkspace* workspace_handle, VWMultiEx* example_handle, void** prediction,
102105
uint32_t* prediction_type, VWErrorMessage* error_message) noexcept;
103106

107+
DLL_PUBLIC int VWWorkspaceRecordExample(
108+
VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept;
109+
DLL_PUBLIC int VWWorkspaceRecordMultiEx(
110+
VWWorkspace* workspace_handle, VWMultiEx* example_handle, VWErrorMessage* error_message) noexcept;
111+
104112
typedef VWExample* VWExampleFactoryFunc(void*);
105113
DLL_PUBLIC int VWWorkspaceParseDSJson(const VWWorkspace* workspace_handle, const char* json_string, size_t length, VWExampleFactoryFunc example_factory, void* example_factory_context,
106114
VWMultiEx* output_handle, VWErrorMessage* error_message) noexcept;

binding/src/bindings.cc

+54-11
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ void VWErrorMessageClearValue(VWErrorMessage* error_message) noexcept
123123
error_message->clear();
124124
}
125125

126+
DLL_PUBLIC void VWWorkspaceDeleteBuffer(const unsigned char* buffer) noexcept { delete[] buffer; }
127+
126128
// VWWorkspace
127129

128130
DLL_PUBLIC int VWWorkspaceInitialize(
@@ -138,8 +140,8 @@ try
138140
}
139141
CATCH_RETURN_EXCEPTION
140142

141-
DLL_PUBLIC int VWWorkspaceInitializeFromModel(
142-
const char* const* extra_tokens, size_t count, const unsigned char* bytes, size_t num_bytes, VWWorkspace** output_handle, VWErrorMessage* error_message) noexcept
143+
DLL_PUBLIC int VWWorkspaceInitializeFromModel(const char* const* extra_tokens, size_t count, const unsigned char* bytes,
144+
size_t num_bytes, VWWorkspace** output_handle, VWErrorMessage* error_message) noexcept
143145
try
144146
{
145147
std::vector<std::string> args(extra_tokens, extra_tokens + count);
@@ -157,7 +159,9 @@ DLL_PUBLIC void VWWorkspaceDelete(VWWorkspace* workspace_handle) noexcept
157159
delete workspace;
158160
}
159161

160-
DLL_PUBLIC int VWWorkspaceSerializeModel(const VWWorkspace* workspace_handle, const unsigned char** bytes, size_t* num_bytes, VWErrorMessage* error_message) noexcept try
162+
DLL_PUBLIC int VWWorkspaceSerializeModel(const VWWorkspace* workspace_handle, const unsigned char** bytes,
163+
size_t* num_bytes, VWErrorMessage* error_message) noexcept
164+
try
161165
{
162166
assert(workspace_handle != nullptr);
163167
auto* workspace = reinterpret_cast<const VW::workspace*>(workspace_handle);
@@ -173,7 +177,9 @@ DLL_PUBLIC int VWWorkspaceSerializeModel(const VWWorkspace* workspace_handle, co
173177
}
174178
CATCH_RETURN_EXCEPTION
175179

176-
DLL_PUBLIC int VWWorkspaceSerializeReadableModel(const VWWorkspace* workspace_handle, const unsigned char** bytes, size_t* num_bytes, VWErrorMessage* error_message) noexcept try
180+
DLL_PUBLIC int VWWorkspaceSerializeReadableModel(const VWWorkspace* workspace_handle, const unsigned char** bytes,
181+
size_t* num_bytes, VWErrorMessage* error_message) noexcept
182+
try
177183
{
178184
assert(workspace_handle != nullptr);
179185
auto* workspace = reinterpret_cast<const VW::workspace*>(workspace_handle);
@@ -189,10 +195,18 @@ DLL_PUBLIC int VWWorkspaceSerializeReadableModel(const VWWorkspace* workspace_ha
189195
}
190196
CATCH_RETURN_EXCEPTION
191197

192-
DLL_PUBLIC void VWWorkspaceDeleteBuffer(const unsigned char* buffer) noexcept
198+
DLL_PUBLIC int VWWorkspaceEndPass(VWWorkspace* workspace_handle, VWErrorMessage* error_message) noexcept
199+
try
193200
{
194-
delete[] buffer;
201+
assert(workspace_handle != nullptr);
202+
assert(example_handle != nullptr);
203+
204+
auto* workspace = reinterpret_cast<VW::workspace*>(workspace_handle);
205+
workspace->current_pass++;
206+
workspace->l->end_pass();
207+
return VW_STATUS_SUCCESS;
195208
}
209+
CATCH_RETURN_EXCEPTION
196210

197211
DLL_PUBLIC int VWWorkspaceSetupExample(
198212
const VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept
@@ -282,8 +296,35 @@ try
282296
}
283297
CATCH_RETURN_EXCEPTION
284298

285-
DLL_PUBLIC int VWWorkspaceParseDSJson(const VWWorkspace* workspace_handle, const char* json_string, size_t length, VWExampleFactoryFunc example_factory, void* example_factory_context,
286-
VWMultiEx* output_handle, VWErrorMessage* error_message) noexcept
299+
DLL_PUBLIC int VWWorkspaceRecordExample(
300+
VWWorkspace* workspace_handle, VWExample* example_handle, VWErrorMessage* error_message) noexcept
301+
try
302+
{
303+
assert(workspace_handle != nullptr);
304+
assert(example_handle != nullptr);
305+
auto* workspace = reinterpret_cast<VW::workspace*>(workspace_handle);
306+
auto* ex = reinterpret_cast<VW::example*>(example_handle);
307+
workspace->finish_example(*ex);
308+
return VW_STATUS_SUCCESS;
309+
}
310+
CATCH_RETURN_EXCEPTION
311+
312+
DLL_PUBLIC int VWWorkspaceRecordMultiEx(
313+
VWWorkspace* workspace_handle, VWMultiEx* example_handle, VWErrorMessage* error_message) noexcept
314+
try
315+
{
316+
assert(workspace_handle != nullptr);
317+
assert(example_handle != nullptr);
318+
auto* workspace = reinterpret_cast<VW::workspace*>(workspace_handle);
319+
auto* ex = reinterpret_cast<VW::multi_ex*>(example_handle);
320+
workspace->finish_example(*ex);
321+
return VW_STATUS_SUCCESS;
322+
}
323+
CATCH_RETURN_EXCEPTION
324+
325+
DLL_PUBLIC int VWWorkspaceParseDSJson(const VWWorkspace* workspace_handle, const char* json_string, size_t length,
326+
VWExampleFactoryFunc example_factory, void* example_factory_context, VWMultiEx* output_handle,
327+
VWErrorMessage* error_message) noexcept
287328
try
288329
{
289330
assert(workspace_handle != nullptr);
@@ -302,11 +343,12 @@ try
302343

303344
using example_factory_t = example& (*)(void*);
304345

305-
example_factory_t factory = [](void* context) -> VW::example& {
346+
example_factory_t factory = [](void* context) -> VW::example&
347+
{
306348
auto* conv = reinterpret_cast<Converter*>(context);
307349
auto* ex = reinterpret_cast<VW::example*>(conv->_func(conv->_ctx));
308350
return *ex;
309-
};
351+
};
310352
auto* workspace = const_cast<VW::workspace*>(reinterpret_cast<const VW::workspace*>(workspace_handle));
311353
auto* multi_ex = reinterpret_cast<VW::multi_ex*>(output_handle);
312354
assert(multi_ex->empty());
@@ -463,7 +505,8 @@ DLL_PUBLIC void VWActionScoresGetLength(const VWActionScores* action_scores_hand
463505
}
464506

465507
DLL_PUBLIC int VWActionScoresGetValue(const VWActionScores* action_scores_handle, uint32_t* action, float* value,
466-
size_t index, VWErrorMessage* error_message) noexcept try
508+
size_t index, VWErrorMessage* error_message) noexcept
509+
try
467510
{
468511
assert(action_scores_handle != nullptr);
469512
auto& a_s = *reinterpret_cast<const ACTION_SCORE::action_scores*>(action_scores_handle);

par-dsjson-unbounded/src/main.rs

-161
This file was deleted.

par-dsjson/Cargo.toml

-12
This file was deleted.

par-dsjson/src/main.rs

-46
This file was deleted.

0 commit comments

Comments
 (0)