Skip to content

Commit d76d464

Browse files
committed
Resolve inline maps in arguments
1 parent a970f11 commit d76d464

File tree

5 files changed

+99
-53
lines changed

5 files changed

+99
-53
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,8 @@ or use a different experiment mode (set to testing mode and load stored model)
7070
```bash
7171
cargo run --release -- experiment.from_file.name=test_model experiment.model.from_file.name=stored experiment.model.path=pretrained.pt
7272
```
73+
74+
you can also use inline maps in TOML style
75+
76+
```
77+
cargo run --release -- 'dataloader.mix_snr = { Uniform = { low = 10, high = 30 }}'

tsap/src/error.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use thiserror::Error;
2+
use std::convert::Infallible;
23

34
pub type Result<T> = std::result::Result<T, Error>;
45

@@ -24,3 +25,9 @@ pub enum Error {
2425
source: std::io::Error,
2526
},
2627
}
28+
29+
impl From<Infallible> for Error {
30+
fn from(_: Infallible) -> Self {
31+
unreachable!()
32+
}
33+
}

tsap/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ mod error;
1111

1212
pub use error::{Result, Error};
1313
#[cfg(feature = "toml")]
14-
pub use toml_builder::{TomlBuilder, toml, serde};
14+
pub use toml_builder::{TomlBuilder, toml, serde, Path};
1515

1616
pub trait ParamGuard {
1717
type Error;

tsap/src/toml_builder.rs

+79-51
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,53 @@ fn merge(mut root: Value, action: Action) -> (Value, Vec<Action>) {
1616
// first iterate through root until we are after our path base
1717
let mut local = &mut root;
1818
let paths = action.path().0.clone();
19-
let num = paths.len() - 1;
2019

21-
for path in &paths[..=num-1] {
22-
match local {
23-
Value::Table(ref mut t) if t.contains_key(path) =>
24-
local = t.get_mut(path).unwrap(),
20+
let deferred = if paths.len() == 0 {
21+
match action {
22+
Action::Delete(_) => { Vec::new() },
23+
Action::Set(_, val) =>
24+
merge_use_second(local, val, Mode::Set),
25+
Action::Modify(_, val) =>
26+
merge_use_second(local, val, Mode::Modify),
27+
}
28+
} else {
29+
let num = paths.len() - 1;
30+
31+
for path in &paths[..num] {
32+
match local {
33+
Value::Table(ref mut t) if t.contains_key(path) =>
34+
local = t.get_mut(path).unwrap(),
35+
_ => return (root, vec![action])
36+
};
37+
}
38+
39+
let local = match local {
40+
Value::Table(ref mut t) if t.contains_key(&paths[num]) => t,
2541
_ => return (root, vec![action])
2642
};
27-
}
2843

29-
let local = match local {
30-
Value::Table(ref mut t) if t.contains_key(&paths[num]) => t,
31-
_ => return (root, vec![action])
32-
};
3344

34-
// now that we are at the base, use recursive merging
35-
let deferred = match action {
36-
Action::Delete(_) => { local.remove(&paths[num]).unwrap(); Vec::new() },
37-
Action::Set(_, val) =>
38-
merge_use_second(local.get_mut(&paths[num]).unwrap(), val, Mode::Set),
39-
Action::Modify(_, val) =>
40-
merge_use_second(local.get_mut(&paths[num]).unwrap(), val, Mode::Modify),
45+
// now that we are at the base, use recursive merging
46+
match action {
47+
Action::Delete(_) => { local.remove(&paths[num]).unwrap(); Vec::new() },
48+
Action::Set(_, val) =>
49+
merge_use_second(local.get_mut(&paths[num]).unwrap(), val, Mode::Set),
50+
Action::Modify(_, val) =>
51+
merge_use_second(local.get_mut(&paths[num]).unwrap(), val, Mode::Modify),
52+
}
4153
};
4254

4355
// add base path again if something is deferred
4456
let deferred = deferred.into_iter().map(|mut p| {
4557
let mut new = paths.clone();
46-
new.append(&mut p.mut_path().0);
58+
new.extend(p.path().0.clone().into_iter().rev());
4759

4860
p.mut_path().0 = new;
4961

5062
p
5163
}).collect();
5264

65+
5366
(root, deferred)
5467
}
5568

@@ -98,13 +111,31 @@ impl FromStr for Path {
98111
type Err = Error;
99112

100113
fn from_str(path: &str) -> Result<Path> {
114+
if path.is_empty() {
115+
return Ok(Path(Vec::new()));
116+
}
117+
101118
let parsed_path = path.split('.').map(|x| x.to_string())
102119
.collect::<Vec<_>>();
103120

104121
Ok(Path(parsed_path))
105122
}
106123
}
107124

125+
impl From<&str> for Path {
126+
fn from(path: &str) -> Path {
127+
if path.is_empty() {
128+
return Path(Vec::new());
129+
}
130+
131+
let parsed_path = path.split('.').map(|x| x.to_string())
132+
.collect::<Vec<_>>();
133+
134+
Path(parsed_path)
135+
136+
}
137+
}
138+
108139
impl std::string::ToString for Path {
109140
fn to_string(&self) -> String {
110141
self.0.join(".")
@@ -223,17 +254,20 @@ impl TomlBuilder {
223254
}
224255

225256
let elm = if mode != Mode::Delete {
226-
if arg.matches('=').count() != 1 {
227-
return Err(Error::InvalidArg(arg));
228-
}
229-
230-
let elms = arg.splitn(2, '=').into_iter().collect::<Vec<_>>();
231-
let path = Path::from_str(&elms[0]).unwrap();
232-
let value = if let Ok(val) = elms[1].parse::<i64>() {
233-
Value::Integer(val)
234-
} else {
235-
Value::String(elms[1].to_string())
236-
};
257+
//if arg.matches('=').count() != 1 {
258+
// return Err(Error::InvalidArg(arg));
259+
//}
260+
261+
//let elms = arg.splitn(2, '=').into_iter().collect::<Vec<_>>();
262+
//let path = Path::from_str(&elms[0]).unwrap();
263+
let value = toml::from_str(&arg)?;
264+
let path = Path(Vec::new());
265+
266+
//let value = if let Ok(val) = elms[1].parse::<i64>() {
267+
// Value::Integer(val)
268+
//} else {
269+
// Value::String(elms[1].to_string())
270+
//};
237271

238272
match mode {
239273
Mode::Modify => Action::Modify(path, value),
@@ -293,27 +327,22 @@ impl TomlBuilder {
293327
}
294328
}
295329

296-
//pub fn amend_file<T: AsRef<Path>>(mut self, path: T) -> Result<Self> {
297-
// let mut f = File::open(path)?;
298-
// let mut content = String::new();
299-
// f.read_to_string(&mut content)?;
300-
// let root: toml::Value = content.parse()?;
301-
// let root = self.templates.resolve(root);
302-
// // merge both dictionaries
303-
// self.root = merge_use_second(self.root, root)?;
304-
// Ok(self)
305-
//}
306-
307-
//pub fn amend<T: TryInto<Value>>(mut self, val: T) -> Result<Self>
308-
// where Error: From<<T as TryInto<Value>>::Error> {
309-
// let root = val.try_into()?;
310-
// let root = self.templates.resolve(root);
311-
312-
// // merge both dictionaries
313-
// self.root = merge_use_second(self.root, root)?;
314-
315-
// Ok(self)
316-
//}
330+
pub fn amend_file<T: AsRef<std::path::Path>>(self, path: T) -> Result<Self> {
331+
let mut f = File::open(path)?;
332+
let mut content = String::new();
333+
f.read_to_string(&mut content)?;
334+
335+
self.amend("", &content)
336+
}
337+
338+
pub fn amend<P: Into<Path>, T: AsRef<str>>(mut self, path: P, val: T) -> Result<Self> {
339+
let path = path.into();
340+
let root = toml::from_str(val.as_ref())?;
341+
342+
self.actions.push(Action::Set(path, root));
343+
344+
Ok(self)
345+
}
317346

318347
pub fn root(self) -> toml::Value {
319348
self.root
@@ -336,6 +365,5 @@ mod tests {
336365
path = 'data/cifar10/'
337366
"#;
338367
let builder: TomlBuilder = content.try_into().unwrap();
339-
dbg!(&builder.root);
340368
}
341369
}

tsap_macro/src/lower_toml.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ impl Intermediate {
9999

100100
impl #builder_name {
101101
pub fn amend_file<T: AsRef<std::path::Path>>(mut self, path: T) -> Result<#builder_name, <#item2 as ParamGuard>::Error> {
102-
//self.0 = self.0.amend_file(path)?;
102+
self.0 = self.0.amend_file(path)?;
103+
104+
Ok(self)
105+
}
106+
107+
pub fn amend<P: Into<tsap::Path>, T: AsRef<str>>(mut self, path: P, val: T) -> Result<Self, <#item2 as ParamGuard>::Error> {
108+
self.0 = self.0.amend(path, val)?;
103109

104110
Ok(self)
105111
}

0 commit comments

Comments
 (0)