Skip to content

Commit 24d1b3f

Browse files
authored
remove sync impl for examples and workspace, update tool (#12)
* remove sync impl for examples
* move thread unsafety to tool itself
* update tool impl
1 parent c888a0f commit 24d1b3f

File tree

4 files changed

+115
-19
lines changed

4 files changed

+115
-19
lines changed

tool/src/main.rs

+115-14
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ struct Cli {
3131
enum Commands {
3232
/// Train a VW model file. Currently only single pass and DSJSON input format is supported.
3333
Train(Train),
34+
/// Train a VW model file without any parallelization. Currently only single pass and DSJSON input format is supported.
35+
TrainOneThread(TrainOneThread),
3436
}
3537

3638
#[derive(Parser, Debug)]
@@ -89,6 +91,41 @@ struct Train {
8991
parse_threads: usize,
9092
}
9193

94+
// Command-line arguments for the `TrainOneThread` subcommand.
// Plain `//` comments are used on purpose: this struct derives `Parser`, and
// doc comments on clap-derive items feed into the generated help output.
#[derive(Parser, Debug)]
95+
struct TrainOneThread {
96+
// One or more input file paths; `main` rejects an empty list before training starts.
#[clap(
97+
short,
98+
long,
99+
parse(from_os_str),
100+
help = "List of input files to process"
101+
)]
102+
input: Vec<PathBuf>,
103+
104+
// Input format selector, defaulting to `InputFormat::Dsjson`.
// NOTE(review): `train_one_thread` never reads this field and always parses
// lines as decision-service JSON — confirm whether that is intended.
#[clap(long, value_enum, default_value_t = InputFormat::Dsjson, help="Input format to interpret input files as")]
105+
input_format: InputFormat,
106+
107+
// Optional destination for the serialized model produced after training.
#[clap(
108+
short,
109+
long,
110+
parse(from_os_str),
111+
help = "If provided, writes the final trained model to this file"
112+
)]
113+
output_model: Option<PathBuf>,
114+
115+
// Optional destination for the human-readable model dump.
#[clap(
116+
long,
117+
parse(from_os_str),
118+
help = "If provided, writes the final trained model as a readable model to this file. This is the same format as VW's --readable_model ..."
119+
)]
120+
readable_model: Option<PathBuf>,
121+
122+
// Raw VW argument string; it is passed to `process_command_line` before the
// workspace is created.
#[clap(
123+
long,
124+
help = "VW arguments to use for model training. Some arguments are not permitted as they are for driver configuration in VW or managed by this tool. For example you cannot supply --data yourself."
125+
)]
126+
model_args: Option<String>,
127+
}
128+
92129
fn process_command_line(input: Option<String>) -> Result<Vec<String>> {
93130
let mut vw_args = match input {
94131
Some(value) => shlex::Shlex::new(&value).collect(),
@@ -194,26 +231,70 @@ fn process_command_line(input: Option<String>) -> Result<Vec<String>> {
194231
Ok(vw_args)
195232
}
196233

234+
/// Wrapper that holds a [`Workspace`] inside an [`UnsafeCell`] so the tool can
/// obtain both shared and mutable references to it through `&self`.
///
/// The type performs no synchronization of its own; see the `Send`/`Sync`
/// impls and the accessor methods for the obligations this places on callers.
pub struct UnsafeWorkspaceWrapper {
235+
// The wrapped workspace. The field is public, so callers can also reach
// the cell directly and bypass `as_ref` / `as_mut`.
pub workspace: UnsafeCell<Workspace>,
236+
}
237+
238+
impl UnsafeWorkspaceWrapper {
239+
/// Returns a shared reference to the wrapped workspace.
///
/// The reference is created from the raw pointer handed out by the
/// `UnsafeCell`; the `unwrap` only guards against a null pointer.
// SAFETY (caller obligation): nothing here checks for an outstanding
// `&mut Workspace` obtained via `as_mut`. The caller must make sure the
// two are not used in a conflicting way.
pub fn as_ref(&self) -> &Workspace {
240+
unsafe { self.workspace.get().as_ref().unwrap() }
241+
}
242+
243+
/// Returns a mutable reference to the wrapped workspace from `&self`.
///
/// No borrow tracking or locking is performed, so several callers can each
/// obtain a `&mut Workspace` to the same value.
// NOTE(review): holding this `&mut` while another thread holds a reference
// from `as_ref` (as `train` does for parsing vs. learning) violates Rust's
// aliasing rules even if the two sides touch disjoint state — confirm the
// intended safety argument and consider making these methods `unsafe fn`.
pub fn as_mut(&self) -> &mut Workspace {
244+
unsafe { self.workspace.get().as_mut().unwrap() }
245+
}
246+
}
247+
248+
// SAFETY: these impls are manual promises to the compiler, not something it
// can check. `UnsafeCell<T>` is never `Sync` on its own, so the `Sync` impl
// below is what allows `&UnsafeWorkspaceWrapper` to be used from several
// threads at once. Nothing in this wrapper enforces exclusive access; callers
// of `as_ref` / `as_mut` are responsible for avoiding data races.
unsafe impl Send for UnsafeWorkspaceWrapper {}
249+
unsafe impl Sync for UnsafeWorkspaceWrapper {}
250+
251+
/// Trains a model on the calling thread only, one input line at a time.
///
/// Steps, in order: turn `args.model_args` into VW arguments, create a
/// `Workspace`, feed every line of every input file through
/// parse -> setup -> learn -> record_stats, call `end_pass`, then write the
/// binary and/or readable model if the corresponding path was supplied.
///
/// # Errors
/// Returns an error if argument processing, workspace creation, parsing,
/// learning, `end_pass`, model serialization or the model file writes fail.
///
/// # Panics
/// Panics if an input file cannot be opened or a line cannot be read.
// NOTE(review): `args.input_format` is not consulted; every line is parsed
// with `parse_decision_service_json`.
fn train_one_thread(args: TrainOneThread) -> Result<()> {
252+
let vw_args = process_command_line(args.model_args)?;
253+
// Examples are returned to this pool after each learn step (see the loop below).
let pool = ExamplePool::new();
254+
let mut workspace = Workspace::new(&vw_args)
255+
.with_context(|| format!("Failed to create workspace with args {:?}", vw_args))?;
256+
257+
for file in args.input {
258+
// NOTE(review): `expect` aborts the process on a missing/unreadable file even
// though this function returns `Result`; consider propagating with `?`.
let file = File::open(file).expect("Failed to open file");
259+
for line in io::BufReader::new(file).lines() {
260+
// NOTE(review): `line.unwrap()` panics on a read error (e.g. invalid UTF-8);
// parse errors, by contrast, are propagated with `?`.
let mut ex =
261+
workspace.setup(workspace.parse_decision_service_json(&line.unwrap(), &pool)?)?;
262+
workspace.learn(&mut ex)?;
263+
workspace.record_stats(&mut ex)?;
264+
// Hand the example back to the pool once learning and stats are done.
pool.return_example(ex);
265+
}
266+
}
267+
// Called once, after all input files have been consumed.
workspace.end_pass()?;
268+
269+
// Both model outputs are optional and independent of each other.
if let Some(model_file) = args.output_model {
270+
fs::write(model_file, &*workspace.serialize_model()?)?;
271+
}
272+
273+
if let Some(model_file) = args.readable_model {
274+
fs::write(model_file, workspace.serialize_readable_model()?)?;
275+
}
276+
277+
Ok(())
278+
}
279+
197280
fn train(args: Train) -> Result<()> {
198281
rayon::ThreadPoolBuilder::new()
199282
.num_threads(args.parse_threads)
200283
.build_global()?;
201284

202285
let vw_args = process_command_line(args.model_args)?;
203286

204-
// TODO process illegal options.
205-
206287
let pool = ExamplePool::new();
207288

208-
// We use an unsafe cell, because parse_decision_service_json, and the learning code does not interact.
209-
let workspace: UnsafeCell<Workspace> = Workspace::new(&vw_args)
289+
let unsafe_workspace_cell: UnsafeCell<Workspace> = Workspace::new(&vw_args)
210290
.with_context(|| format!("Failed to create workspace with args {:?}", vw_args))?
211291
.into();
292+
// We use an unsafe cell, because parse_decision_service_json, and the learning code does not interact.
293+
let shareable_workspace: UnsafeWorkspaceWrapper = UnsafeWorkspaceWrapper {
294+
workspace: unsafe_workspace_cell,
295+
};
212296
let (tx, rx) = flume::bounded(args.queue_size);
213297

214-
let ws_ref = unsafe { workspace.get().as_ref().unwrap() };
215-
let ws = unsafe { workspace.get().as_mut().unwrap() };
216-
217298
std::thread::scope(|s| -> Result<()> {
218299
s.spawn(|| {
219300
for file in args.input {
@@ -233,7 +314,11 @@ fn train(args: Train) -> Result<()> {
233314
}
234315
let output_lines: Vec<_> = batch
235316
.into_par_iter()
236-
.map(|line| ws_ref.parse_decision_service_json(&line, &pool))
317+
.map(|line| {
318+
shareable_workspace
319+
.as_ref()
320+
.parse_decision_service_json(&line, &pool)
321+
})
237322
.collect();
238323

239324
for line in output_lines {
@@ -247,29 +332,33 @@ fn train(args: Train) -> Result<()> {
247332
std::mem::drop(tx);
248333
});
249334

335+
let unsafe_workspace_ref = shareable_workspace.as_mut();
336+
250337
loop {
251338
// TODO consider skipping broken examples.
252339
let res = rx.recv();
253340
match res {
254341
Ok(line) => {
255-
let mut ex = ws.setup(line?)?;
256-
ws.learn(&mut ex)?;
257-
ws.record_stats(&mut ex)?;
342+
let mut ex = unsafe_workspace_ref.setup(line?)?;
343+
unsafe_workspace_ref.learn(&mut ex)?;
344+
unsafe_workspace_ref.record_stats(&mut ex)?;
258345
pool.return_example(ex);
259346
}
260347
// Sender has been dropped. Stop here.
261348
Err(_) => break,
262349
}
263350
}
264-
ws.end_pass()?;
351+
unsafe_workspace_ref.end_pass()?;
265352
Ok(())
266353
})?;
354+
355+
let unsafe_workspace_ref = shareable_workspace.as_ref();
267356
if let Some(model_file) = args.output_model {
268-
fs::write(model_file, &*ws_ref.serialize_model()?)?;
357+
fs::write(model_file, &*unsafe_workspace_ref.serialize_model()?)?;
269358
}
270359

271360
if let Some(model_file) = args.readable_model {
272-
fs::write(model_file, ws_ref.serialize_readable_model()?)?;
361+
fs::write(model_file, unsafe_workspace_ref.serialize_readable_model()?)?;
273362
}
274363

275364
Ok(())
@@ -290,5 +379,17 @@ fn main() -> Result<()> {
290379
}
291380
train(args)
292381
}
382+
Commands::TrainOneThread(args) => {
383+
if args.input.is_empty() {
384+
let mut app = Cli::into_app();
385+
let sub = app
386+
.find_subcommand_mut("train")
387+
.expect("train must exist as a subcommand");
388+
sub.print_help()?;
389+
return Err(anyhow!("At least 1 input file is required."));
390+
// return;
391+
}
392+
train_one_thread(args)
393+
}
293394
}
294395
}

vowpalwabbit/src/example.rs

-2
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,7 @@ pub struct RawExample {
4040
}
4141

4242
unsafe impl Send for Example {}
43-
unsafe impl Sync for Example {}
4443
unsafe impl Send for RawExample {}
45-
unsafe impl Sync for RawExample {}
4644
impl RawExample {
4745
pub fn new() -> RawExample {
4846
unsafe {

vowpalwabbit/src/multi_example.rs

-2
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,7 @@ pub struct RawMultiExample {
2424
}
2525

2626
unsafe impl Send for MultiExample {}
27-
unsafe impl Sync for MultiExample {}
2827
unsafe impl Send for RawMultiExample {}
29-
unsafe impl Sync for RawMultiExample {}
3028

3129
impl RawMultiExample {
3230
pub fn new() -> RawMultiExample {

vowpalwabbit/src/workspace.rs

-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@ pub struct Workspace {
1515
}
1616

1717
unsafe impl Send for Workspace {}
18-
unsafe impl Sync for Workspace {}
1918

2019
unsafe fn get_action_scores_or_probs(pred_ptr: *mut c_void) -> Vec<(u32, f32)> {
2120
let mut length = MaybeUninit::<size_t>::zeroed();

0 commit comments

Comments
 (0)