Skip to content

Commit 73d770f

Browse files
authored
Merge pull request #58 from marshallpierce/mp/string-type
Add String data type support
2 parents b2291e4 + a98d630 commit 73d770f

File tree

6 files changed

+232
-37
lines changed

6 files changed

+232
-37
lines changed

onnxruntime-sys/build.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ fn generate_bindings(include_dir: &Path) {
105105
.expect("Couldn't write bindings!");
106106
}
107107

108-
fn download<P: AsRef<Path>>(source_url: &str, target_file: P) {
108+
fn download<P>(source_url: &str, target_file: P)
109+
where
110+
P: AsRef<Path>,
111+
{
109112
let resp = ureq::get(source_url)
110113
.timeout_connect(1_000) // 1 second
111114
.timeout(std::time::Duration::from_secs(300))
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//! Display the input and output structure of an ONNX model.
2+
use onnxruntime::environment;
3+
use std::error::Error;
4+
5+
fn main() -> Result<(), Box<dyn Error>> {
6+
// provide path to .onnx model on disk
7+
let path = std::env::args()
8+
.skip(1)
9+
.next()
10+
.expect("Must provide an .onnx file as the first arg");
11+
12+
let environment = environment::Environment::builder()
13+
.with_name("onnx metadata")
14+
.with_log_level(onnxruntime::LoggingLevel::Verbose)
15+
.build()?;
16+
17+
let session = environment
18+
.new_session_builder()?
19+
.with_optimization_level(onnxruntime::GraphOptimizationLevel::Basic)?
20+
.with_model_from_file(path)?;
21+
22+
println!("Inputs:");
23+
for (index, input) in session.inputs.iter().enumerate() {
24+
println!(
25+
" {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
26+
index, input.name, input.input_type, input.dimensions
27+
)
28+
}
29+
30+
println!("Outputs:");
31+
for (index, output) in session.outputs.iter().enumerate() {
32+
println!(
33+
" {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
34+
index, output.name, output.output_type, output.dimensions
35+
);
36+
}
37+
38+
Ok(())
39+
}

onnxruntime/src/error.rs

+14
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,15 @@ pub enum OrtError {
5656
/// Error occurred when creating CPU memory information
5757
#[error("Failed to get dimensions: {0}")]
5858
CreateCpuMemoryInfo(OrtApiError),
59+
/// Error occurred when creating ONNX tensor
60+
#[error("Failed to create tensor: {0}")]
61+
CreateTensor(OrtApiError),
5962
/// Error occurred when creating ONNX tensor with specific data
6063
#[error("Failed to create tensor with data: {0}")]
6164
CreateTensorWithData(OrtApiError),
65+
/// Error occurred when filling a tensor with string data
66+
#[error("Failed to fill string tensor: {0}")]
67+
FillStringTensor(OrtApiError),
6268
/// Error occurred when checking if ONNX tensor was properly initialized
6369
#[error("Failed to check if tensor: {0}")]
6470
IsTensor(OrtApiError),
@@ -183,3 +189,11 @@ pub(crate) fn status_to_result(
183189
let status_wrapper: OrtStatusWrapper = status.into();
184190
status_wrapper.into()
185191
}
192+
193+
/// A wrapper around a function on OrtApi that maps the status code into [OrtApiError]
194+
pub(crate) unsafe fn call_ort<F>(mut f: F) -> std::result::Result<(), OrtApiError>
195+
where
196+
F: FnMut(sys::OrtApi) -> *const sys::OrtStatus,
197+
{
198+
status_to_result(f(g_ort()))
199+
}

onnxruntime/src/lib.rs

+46-9
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ pub enum TensorElementDataType {
343343
Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
344344
/// Signed 64-bit int, equivalent to Rust's `i64`
345345
Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
346-
// /// String, equivalent to Rust's `String`
347-
// String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
346+
/// String, equivalent to Rust's `String`
347+
String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
348348
// /// Boolean, equivalent to Rust's `bool`
349349
// Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
350350
// /// 16-bit floating point, equivalent to Rust's `f16`
@@ -374,9 +374,7 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
374374
Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
375375
Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
376376
Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
377-
// String => {
378-
// sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
379-
// }
377+
String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
380378
// Bool => {
381379
// sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
382380
// }
@@ -402,15 +400,22 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
402400
/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
403401
pub trait TypeToTensorElementDataType {
404402
/// Return the ONNX type for a Rust type
405-
fn tensor_element_data_type() -> sys::ONNXTensorElementDataType;
403+
fn tensor_element_data_type() -> TensorElementDataType;
404+
405+
/// If the type is `String`, returns `Some` with utf8 contents, else `None`.
406+
fn try_utf8_bytes(&self) -> Option<&[u8]>;
406407
}
407408

408409
macro_rules! impl_type_trait {
409410
($type_:ty, $variant:ident) => {
410411
impl TypeToTensorElementDataType for $type_ {
411-
fn tensor_element_data_type() -> sys::ONNXTensorElementDataType {
412+
fn tensor_element_data_type() -> TensorElementDataType {
412413
// unsafe { std::mem::transmute(TensorElementDataType::$variant) }
413-
TensorElementDataType::$variant.into()
414+
TensorElementDataType::$variant
415+
}
416+
417+
fn try_utf8_bytes(&self) -> Option<&[u8]> {
418+
None
414419
}
415420
}
416421
};
@@ -423,7 +428,6 @@ impl_type_trait!(u16, Uint16);
423428
impl_type_trait!(i16, Int16);
424429
impl_type_trait!(i32, Int32);
425430
impl_type_trait!(i64, Int64);
426-
// impl_type_trait!(String, String);
427431
// impl_type_trait!(bool, Bool);
428432
// impl_type_trait!(f16, Float16);
429433
impl_type_trait!(f64, Double);
@@ -433,6 +437,39 @@ impl_type_trait!(u64, Uint64);
433437
// impl_type_trait!(, Complex128);
434438
// impl_type_trait!(, Bfloat16);
435439

440+
/// Adapter for common Rust string types to Onnx strings.
441+
///
442+
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
443+
/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
444+
/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
445+
/// types (which might implement `AsRef<str>` at some point in the future).
446+
pub trait Utf8Data {
447+
/// Returns the utf8 contents.
448+
fn utf8_bytes(&self) -> &[u8];
449+
}
450+
451+
impl Utf8Data for String {
452+
fn utf8_bytes(&self) -> &[u8] {
453+
self.as_bytes()
454+
}
455+
}
456+
457+
impl<'a> Utf8Data for &'a str {
458+
fn utf8_bytes(&self) -> &[u8] {
459+
self.as_bytes()
460+
}
461+
}
462+
463+
impl<T: Utf8Data> TypeToTensorElementDataType for T {
464+
fn tensor_element_data_type() -> TensorElementDataType {
465+
TensorElementDataType::String
466+
}
467+
468+
fn try_utf8_bytes(&self) -> Option<&[u8]> {
469+
Some(self.utf8_bytes())
470+
}
471+
}
472+
436473
/// Allocator type
437474
#[derive(Debug, Clone)]
438475
#[repr(i32)]

onnxruntime/src/session.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ impl<'a> Session<'a> {
411411
// The C API expects pointers for the arrays (pointers to C-arrays)
412412
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
413413
.into_iter()
414-
.map(|input_array| OrtTensor::from_array(&self.memory_info, input_array))
414+
.map(|input_array| {
415+
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
416+
})
415417
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
416418
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
417419
.iter()

0 commit comments

Comments
 (0)