diff --git a/onnxruntime/examples/issue22.rs b/onnxruntime/examples/issue22.rs index b2879b91..9dbd5d5b 100644 --- a/onnxruntime/examples/issue22.rs +++ b/onnxruntime/examples/issue22.rs @@ -34,7 +34,12 @@ fn main() { let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap(); let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap(); - let outputs: Vec<OrtOwnedTensor<f32, _>> = - session.run(vec![input_ids, attention_mask]).unwrap(); + let outputs: Vec<OrtOwnedTensor<f32, _>> = session + .run(vec![input_ids, attention_mask]) + .unwrap() + .into_iter() + .map(|dyn_tensor| dyn_tensor.try_extract()) + .collect::<Result<_, _>>() + .unwrap(); print!("outputs: {:#?}", outputs); } diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index d16d08da..3fbc2670 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -1,8 +1,10 @@ #![forbid(unsafe_code)] use onnxruntime::{ - environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel, - LoggingLevel, + environment::Environment, + ndarray::Array, + tensor::{DynOrtTensor, OrtOwnedTensor}, + GraphOptimizationLevel, LoggingLevel, }; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -61,11 +63,12 @@ fn run() -> Result<(), Error> { .unwrap(); let input_tensor_values = vec![array]; - let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?; + let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?; - assert_eq!(outputs[0].shape(), output0_shape.as_slice()); + let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap(); + assert_eq!(output.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); } Ok(()) diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs index 0d575b5e..6ae7c333 100644 --- a/onnxruntime/src/lib.rs 
+++ b/onnxruntime/src/lib.rs @@ -104,7 +104,10 @@ to download. //! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100); //! // Multiple inputs and outputs are possible //! let input_tensor = vec![array]; -//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?; +//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)? +//! .into_iter() +//! .map(|dyn_tensor| dyn_tensor.try_extract()) +//! .collect::<Result<_, _>>()?; //! # Ok(()) //! # } //! ``` @@ -115,7 +118,10 @@ to download. //! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs) //! example for more details. -use std::sync::{atomic::AtomicPtr, Arc, Mutex}; +use std::{ + ffi, ptr, + sync::{atomic::AtomicPtr, Arc, Mutex}, +}; use lazy_static::lazy_static; @@ -142,7 +148,7 @@ lazy_static! { // } as *mut sys::OrtApi))); static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = { let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() }; - assert_ne!(base, std::ptr::null()); + assert_ne!(base, ptr::null()); let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi = unsafe { (*base).GetApi.unwrap() }; let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) }; @@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi { let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut(); let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut; - assert_ne!(api_ptr_mut, std::ptr::null_mut()); + assert_ne!(api_ptr_mut, ptr::null_mut()); unsafe { *api_ptr_mut } } fn char_p_to_string(raw: *const i8) -> Result<String> { - let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; + let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; match c_string.into_string() { Ok(string) => Ok(string), @@ -176,7 +182,7 @@ mod onnxruntime { //! Module containing a custom logger, used to catch the runtime's own logging and send it //! to Rust's tracing logging instead. 
- use std::ffi::CStr; + use std::{ffi, ffi::CStr, ptr}; use tracing::{debug, error, info, span, trace, warn, Level}; use onnxruntime_sys as sys; @@ -212,7 +218,7 @@ mod onnxruntime { /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate. pub(crate) extern "C" fn custom_logger( - _params: *mut std::ffi::c_void, + _params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const i8, logid: *const i8, @@ -227,16 +233,16 @@ mod onnxruntime { sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR, }; - assert_ne!(category, std::ptr::null()); + assert_ne!(category, ptr::null()); let category = unsafe { CStr::from_ptr(category) }; - assert_ne!(code_location, std::ptr::null()); + assert_ne!(code_location, ptr::null()); let code_location = unsafe { CStr::from_ptr(code_location) } .to_str() .unwrap_or("unknown"); - assert_ne!(message, std::ptr::null()); + assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }; - assert_ne!(logid, std::ptr::null()); + assert_ne!(logid, ptr::null()); let logid = unsafe { CStr::from_ptr(logid) }; // Parse the code location @@ -322,154 +328,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel { } } -// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum -// FIXME: Add tests to cover the commented out types -/// Enum mapping ONNX Runtime's supported tensor types -#[derive(Debug)] -#[cfg_attr(not(windows), repr(u32))] -#[cfg_attr(windows, repr(i32))] -pub enum TensorElementDataType { - /// 32-bit floating point, equivalent to Rust's `f32` - Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, - /// Unsigned 8-bit int, equivalent to Rust's `u8` - Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, - /// Signed 8-bit int, equivalent to Rust's `i8` - Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as 
OnnxEnumInt, - /// Unsigned 16-bit int, equivalent to Rust's `u16` - Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, - /// Signed 16-bit int, equivalent to Rust's `i16` - Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, - /// Signed 32-bit int, equivalent to Rust's `i32` - Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, - /// Signed 64-bit int, equivalent to Rust's `i64` - Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, - /// String, equivalent to Rust's `String` - String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, - // /// Boolean, equivalent to Rust's `bool` - // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, - // /// 16-bit floating point, equivalent to Rust's `f16` - // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, - /// 64-bit floating point, equivalent to Rust's `f64` - Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, - /// Unsigned 32-bit int, equivalent to Rust's `u32` - Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, - /// Unsigned 64-bit int, equivalent to Rust's `u64` - Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, - // /// Complex 64-bit floating point, equivalent to Rust's `???` - // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, - // /// Complex 128-bit floating point, equivalent to Rust's `???` - // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, - // /// Brain 16-bit floating point - // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, -} - -impl 
Into<sys::ONNXTensorElementDataType> for TensorElementDataType { - fn into(self) -> sys::ONNXTensorElementDataType { - use TensorElementDataType::*; - match self { - Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, - Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, - Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, - Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, - Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, - Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, - // Bool => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - // } - // Float16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - // } - Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, - Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, - Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, - // Complex64 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 - // } - // Complex128 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 - // } - // Bfloat16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 - // } - } - } -} - -/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) -pub trait TypeToTensorElementDataType { - /// Return the ONNX type for a Rust type - fn tensor_element_data_type() -> TensorElementDataType; - - /// If the type is `String`, returns `Some` with utf8 contents, else `None`. - fn try_utf8_bytes(&self) -> Option<&[u8]>; -} - -macro_rules! 
impl_type_trait { - ($type_:ty, $variant:ident) => { - impl TypeToTensorElementDataType for $type_ { - fn tensor_element_data_type() -> TensorElementDataType { - // unsafe { std::mem::transmute(TensorElementDataType::$variant) } - TensorElementDataType::$variant - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - None - } - } - }; -} - -impl_type_trait!(f32, Float); -impl_type_trait!(u8, Uint8); -impl_type_trait!(i8, Int8); -impl_type_trait!(u16, Uint16); -impl_type_trait!(i16, Int16); -impl_type_trait!(i32, Int32); -impl_type_trait!(i64, Int64); -// impl_type_trait!(bool, Bool); -// impl_type_trait!(f16, Float16); -impl_type_trait!(f64, Double); -impl_type_trait!(u32, Uint32); -impl_type_trait!(u64, Uint64); -// impl_type_trait!(, Complex64); -// impl_type_trait!(, Complex128); -// impl_type_trait!(, Bfloat16); - -/// Adapter for common Rust string types to Onnx strings. -/// -/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but -/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it -/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric -/// types (which might implement `AsRef<str>` at some point in the future). -pub trait Utf8Data { - /// Returns the utf8 contents. 
- fn utf8_bytes(&self) -> &[u8]; -} - -impl Utf8Data for String { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<'a> Utf8Data for &'a str { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<T: Utf8Data> TypeToTensorElementDataType for T { - fn tensor_element_data_type() -> TensorElementDataType { - TensorElementDataType::String - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - Some(self.utf8_bytes()) - } -} - /// Allocator type #[derive(Debug, Clone)] #[repr(i32)] @@ -524,7 +382,7 @@ mod test { #[test] fn test_char_p_to_string() { - let s = std::ffi::CString::new("foo").unwrap(); + let s = ffi::CString::new("foo").unwrap(); let ptr = s.as_c_str().as_ptr(); assert_eq!("foo", char_p_to_string(ptr).unwrap()); } diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..232d188d 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -18,15 +18,11 @@ use onnxruntime_sys as sys; use crate::{ char_p_to_string, environment::Environment, - error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, + error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, - tensor::{ - ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, - OrtTensor, - }, - AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, - TypeToTensorElementDataType, + tensor::{DynOrtTensor, OrtTensor, TensorElementDataType, TypeToTensorElementDataType}, + AllocatorType, GraphOptimizationLevel, MemType, }; #[cfg(feature = "model-fetching")] @@ -365,13 +361,12 @@ impl<'a> Session<'a> { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. 
- pub fn run<'s, 't, 'm, TIn, TOut, D>( + pub fn run<'s, 't, 'm, TIn, D>( &'s mut self, input_arrays: Vec<Array<TIn, D>>, - ) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>> + ) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>> where TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, 'm: 't, // 'm outlives 't (memory info outlives tensor) 's: 'm, // 's outlives 'm (session outlives memory info) @@ -405,7 +400,7 @@ impl<'a> Session<'a> { .map(|n| n.as_ptr() as *const i8) .collect(); - let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = + let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; // The C API expects pointers for the arrays (pointers to C-arrays) @@ -431,30 +426,33 @@ impl<'a> Session<'a> { input_ort_values.len() as u64, // C API expects a u64, not isize output_names_ptr.as_ptr(), output_names_ptr.len() as u64, // C API expects a u64, not isize - output_tensor_extractors_ptrs.as_mut_ptr(), + output_tensor_ptrs.as_mut_ptr(), ) }; status_to_result(status).map_err(OrtError::Run)?; let memory_info_ref = &self.memory_info; - let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> = - output_tensor_extractors_ptrs + let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = + output_tensor_ptrs .into_iter() - .map(|ptr| { - let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = - std::ptr::null_mut(); - let status = unsafe { - g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) - }; - status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; - let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; - unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; - let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); - - let mut output_tensor_extractor = - OrtOwnedTensorExtractor::new(memory_info_ref, 
ndarray::IxDyn(&dims)); - output_tensor_extractor.tensor_ptr = ptr; - output_tensor_extractor.extract::<TOut>() + .map(|tensor_ptr| { + let (dims, data_type) = unsafe { + call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { + get_tensor_dimensions(tensor_info_ptr) + .map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>()) + .and_then(|dims| { + extract_data_type(tensor_info_ptr) + .map(|data_type| (dims, data_type)) + }) + }) + }?; + + Ok(DynOrtTensor::new( + tensor_ptr, + memory_info_ref, + ndarray::IxDyn(&dims), + data_type, + )) }) .collect(); @@ -554,25 +552,60 @@ unsafe fn get_tensor_dimensions( tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, ) -> Result<Vec<i64>> { let mut num_dims = 0; - let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); - status_to_result(status).map_err(OrtError::GetDimensionsCount)?; + call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)) + .map_err(OrtError::GetDimensionsCount)?; assert_ne!(num_dims, 0); let mut node_dims: Vec<i64> = vec![0; num_dims as usize]; - let status = g_ort().GetDimensions.unwrap()( - tensor_info_ptr, - node_dims.as_mut_ptr(), // FIXME: UB? - num_dims, - ); - status_to_result(status).map_err(OrtError::GetDimensions)?; + call_ort(|ort| { + ort.GetDimensions.unwrap()( + tensor_info_ptr, + node_dims.as_mut_ptr(), // FIXME: UB? 
+ num_dims, + ) + }) + .map_err(OrtError::GetDimensions)?; Ok(node_dims) } +unsafe fn extract_data_type( + tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, +) -> Result<TensorElementDataType> { + let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys)) + .map_err(OrtError::TensorElementType)?; + assert_ne!( + type_sys, + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED + ); + // This transmute should be safe since its value is read from GetTensorElementType which we must trust. + Ok(std::mem::transmute(type_sys)) +} + +/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the +/// resulting `*OrtTensorTypeAndShapeInfo` before returning. +unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result<T> +where + F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result<T>, +{ + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _)) + .map_err(OrtError::GetTensorTypeAndShape)?; + + let res = f(tensor_info_ptr); + + // no return code, so no errors to check for + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr); + + res +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. 
mod dangerous { use super::*; + use crate::tensor::TensorElementDataType; pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<u64> { let f = g_ort().SessionGetInputCount.unwrap(); @@ -689,16 +722,7 @@ mod dangerous { status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; assert_ne!(tensor_info_ptr, std::ptr::null_mut()); - let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - let status = - unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) }; - status_to_result(status).map_err(OrtError::TensorElementType)?; - assert_ne!( - type_sys, - sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED - ); - // This transmute should be safe since its value is read from GetTensorElementType which we must trust. - let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; + let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? 
}; // info!("{} : type={}", i, type_); diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index 92404842..df85e1ed 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -27,5 +27,231 @@ pub mod ndarray_tensor; pub mod ort_owned_tensor; pub mod ort_tensor; -pub use ort_owned_tensor::OrtOwnedTensor; +pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor}; pub use ort_tensor::OrtTensor; + +use crate::{OrtError, Result}; +use onnxruntime_sys::{self as sys, OnnxEnumInt}; +use std::{fmt, ptr}; + +// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum +// FIXME: Add tests to cover the commented out types +/// Enum mapping ONNX Runtime's supported tensor types +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(not(windows), repr(u32))] +#[cfg_attr(windows, repr(i32))] +pub enum TensorElementDataType { + /// 32-bit floating point, equivalent to Rust's `f32` + Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, + /// Unsigned 8-bit int, equivalent to Rust's `u8` + Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, + /// Signed 8-bit int, equivalent to Rust's `i8` + Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, + /// Unsigned 16-bit int, equivalent to Rust's `u16` + Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, + /// Signed 16-bit int, equivalent to Rust's `i16` + Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, + /// Signed 32-bit int, equivalent to Rust's `i32` + Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, + /// Signed 64-bit int, equivalent to Rust's `i64` + Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, + /// String, equivalent to Rust's `String` + String = 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, + // /// Boolean, equivalent to Rust's `bool` + // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, + // /// 16-bit floating point, equivalent to Rust's `f16` + // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, + /// 64-bit floating point, equivalent to Rust's `f64` + Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, + /// Unsigned 32-bit int, equivalent to Rust's `u32` + Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, + /// Unsigned 64-bit int, equivalent to Rust's `u64` + Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, + // /// Complex 64-bit floating point, equivalent to Rust's `???` + // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, + // /// Complex 128-bit floating point, equivalent to Rust's `???` + // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, + // /// Brain 16-bit floating point + // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, +} + +impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType { + fn into(self) -> sys::ONNXTensorElementDataType { + use TensorElementDataType::*; + match self { + Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + Int64 => 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + // Bool => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + // } + // Float16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + // } + Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + // Complex64 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + // } + // Complex128 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + // } + // Bfloat16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + // } + } + } +} + +/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) +pub trait TypeToTensorElementDataType { + /// Return the ONNX type for a Rust type + fn tensor_element_data_type() -> TensorElementDataType; + + /// If the type is `String`, returns `Some` with utf8 contents, else `None`. + fn try_utf8_bytes(&self) -> Option<&[u8]>; +} + +macro_rules! 
impl_prim_type_to_ort_trait { + ($type_:ty, $variant:ident) => { + impl TypeToTensorElementDataType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + // unsafe { std::mem::transmute(TensorElementDataType::$variant) } + TensorElementDataType::$variant + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + None + } + } + }; +} + +impl_prim_type_to_ort_trait!(f32, Float); +impl_prim_type_to_ort_trait!(u8, Uint8); +impl_prim_type_to_ort_trait!(i8, Int8); +impl_prim_type_to_ort_trait!(u16, Uint16); +impl_prim_type_to_ort_trait!(i16, Int16); +impl_prim_type_to_ort_trait!(i32, Int32); +impl_prim_type_to_ort_trait!(i64, Int64); +// impl_type_trait!(bool, Bool); +// impl_type_trait!(f16, Float16); +impl_prim_type_to_ort_trait!(f64, Double); +impl_prim_type_to_ort_trait!(u32, Uint32); +impl_prim_type_to_ort_trait!(u64, Uint64); +// impl_type_trait!(, Complex64); +// impl_type_trait!(, Complex128); +// impl_type_trait!(, Bfloat16); + +/// Adapter for common Rust string types to Onnx strings. +/// +/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but +/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it +/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric +/// types (which might implement `AsRef<str>` at some point in the future). +pub trait Utf8Data { + /// Returns the utf8 contents. 
+ fn utf8_bytes(&self) -> &[u8]; +} + +impl Utf8Data for String { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a> Utf8Data for &'a str { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<T: Utf8Data> TypeToTensorElementDataType for T { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + Some(self.utf8_bytes()) + } +} + +/// Trait used to map onnxruntime types to Rust types +pub trait TensorDataToType: Sized + fmt::Debug { + /// The tensor element type that this type can extract from + fn tensor_element_data_type() -> TensorElementDataType; + + /// Extract an `ArrayView` from the ort-owned tensor. + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result<ndarray::ArrayView<'t, Self, D>> + where + D: ndarray::Dimension; +} + +/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData` +macro_rules! impl_prim_type_from_ort_trait { + ($type_:ty, $variant:ident) => { + impl TensorDataToType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::$variant + } + + fn extract_array<'t, D>( + shape: D, + tensor: *mut sys::OrtValue, + ) -> Result<ndarray::ArrayView<'t, Self, D>> + where + D: ndarray::Dimension, + { + extract_primitive_array(shape, tensor) + } + } + }; +} + +/// Construct an [ndarray::ArrayView] over an Ort tensor. +/// +/// Only to be used on types whose Rust in-memory representation matches Ort's (e.g. primitive +/// numeric types like u32). 
+fn extract_primitive_array<'t, D, T: TensorDataToType>( + shape: D, + tensor: *mut sys::OrtValue, +) -> Result<ndarray::ArrayView<'t, T, D>> +where + D: ndarray::Dimension, +{ + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr as *mut *mut std::ffi::c_void; + unsafe { + crate::error::call_ort(|ort| { + ort.GetTensorMutableData.unwrap()(tensor, output_array_ptr_ptr_void) + }) + } + .map_err(OrtError::GetTensorMutableData)?; + assert_ne!(output_array_ptr, ptr::null_mut()); + + let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) }; + Ok(array_view) +} + +impl_prim_type_from_ort_trait!(f32, Float); +impl_prim_type_from_ort_trait!(u8, Uint8); +impl_prim_type_from_ort_trait!(i8, Int8); +impl_prim_type_from_ort_trait!(u16, Uint16); +impl_prim_type_from_ort_trait!(i16, Int16); +impl_prim_type_from_ort_trait!(i32, Int32); +impl_prim_type_from_ort_trait!(i64, Int64); +impl_prim_type_from_ort_trait!(f64, Double); +impl_prim_type_from_ort_trait!(u32, Uint32); +impl_prim_type_from_ort_trait!(u64, Uint64); diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 161fe105..48f48308 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -1,17 +1,123 @@ //! 
Module containing tensor with memory owned by the ONNX Runtime -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, ops::Deref, ptr, rc, result}; use ndarray::{Array, ArrayView}; +use thiserror::Error; use tracing::debug; use onnxruntime_sys as sys; use crate::{ - error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor, - OrtError, Result, TypeToTensorElementDataType, + error::call_ort, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, TensorDataToType, TensorElementDataType}, + OrtError, }; +/// Errors that can occur while extracting a tensor from ort output. +#[derive(Error, Debug)] +pub enum TensorExtractError { + /// The user tried to extract the wrong type of tensor from the underlying data + #[error( + "Data type mismatch: was {:?}, tried to convert to {:?}", + actual, + requested + )] + DataTypeMismatch { + /// The actual type of the ort output + actual: TensorElementDataType, + /// The type corresponding to the attempted conversion into a Rust type, not equal to `actual` + requested: TensorElementDataType, + }, + /// An onnxruntime error occurred + #[error("Onnxruntime error: {:?}", .0)] + OrtError(#[from] OrtError), +} + +/// A wrapper around a tensor produced by onnxruntime inference. +/// +/// Since different outputs for the same model can have different types, this type is used to allow +/// the user to dynamically query each output's type and extract the appropriate tensor type with +/// [try_extract].
+#[derive(Debug)] +pub struct DynOrtTensor<'m, D> +where + D: ndarray::Dimension, +{ + tensor_ptr_holder: rc::Rc<TensorPointerDropper>, + memory_info: &'m MemoryInfo, + shape: D, + data_type: TensorElementDataType, +} + +impl<'m, D> DynOrtTensor<'m, D> +where + D: ndarray::Dimension, +{ + pub(crate) fn new( + tensor_ptr: *mut sys::OrtValue, + memory_info: &'m MemoryInfo, + shape: D, + data_type: TensorElementDataType, + ) -> DynOrtTensor<'m, D> { + DynOrtTensor { + tensor_ptr_holder: rc::Rc::from(TensorPointerDropper { tensor_ptr }), + memory_info, + shape, + data_type, + } + } + + /// The ONNX data type this tensor contains. + pub fn data_type(&self) -> TensorElementDataType { + self.data_type + } + + /// Extract a tensor containing `T`. + /// + /// Where the type permits it, the tensor will be a view into existing memory. + /// + /// # Errors + /// + /// An error will be returned if `T`'s ONNX type doesn't match this tensor's type, or if an + /// onnxruntime error occurs. + pub fn try_extract<'t, T>(&self) -> result::Result<OrtOwnedTensor<'t, T, D>, TensorExtractError> + where + T: TensorDataToType + Clone + Debug, + 'm: 't, // mem info outlives tensor + { + if self.data_type != T::tensor_element_data_type() { + Err(TensorExtractError::DataTypeMismatch { + actual: self.data_type, + requested: T::tensor_element_data_type(), + }) + } else { + // Note: Both tensor and array will point to the same data, nothing is copied. + // As such, there is no need to free the pointer used to create the ArrayView. 
+ assert_ne!(self.tensor_ptr_holder.tensor_ptr, ptr::null_mut()); + + let mut is_tensor = 0; + unsafe { + call_ort(|ort| { + ort.IsTensor.unwrap()(self.tensor_ptr_holder.tensor_ptr, &mut is_tensor) + }) + } + .map_err(OrtError::IsTensor)?; + assert_eq!(is_tensor, 1); + + let array_view = + T::extract_array(self.shape.clone(), self.tensor_ptr_holder.tensor_ptr)?; + + Ok(OrtOwnedTensor::new( + self.tensor_ptr_holder.clone(), + array_view, + )) + } + } +} + /// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. /// /// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method. @@ -23,20 +129,19 @@ use crate::{ /// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to /// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). #[derive(Debug)] -pub struct OrtOwnedTensor<'t, 'm, T, D> +pub struct OrtOwnedTensor<'t, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, - 'm: 't, // 'm outlives 't { - pub(crate) tensor_ptr: *mut sys::OrtValue, + /// Keep the pointer alive + tensor_ptr_holder: rc::Rc<TensorPointerDropper>, array_view: ArrayView<'t, T, D>, - memory_info: &'m MemoryInfo, } -impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> +impl<'t, T, D> Deref for OrtOwnedTensor<'t, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { type Target = ArrayView<'t, T, D>; @@ -46,11 +151,21 @@ where } } -impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> +impl<'t, T, D> OrtOwnedTensor<'t, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { + pub(crate) fn new( + tensor_ptr_holder: rc::Rc<TensorPointerDropper>, + array_view: ArrayView<'t, T, D>, + ) -> OrtOwnedTensor<'t, T, D> { + OrtOwnedTensor { + tensor_ptr_holder, + array_view, 
+ } + } + /// Apply a softmax on the specified axis pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> where @@ -61,74 +176,23 @@ where } } +/// Holds on to a tensor pointer until dropped. +/// +/// This allows creating an [OrtOwnedTensor] from a [DynOrtTensor] without consuming `self`, which +/// would prevent retrying extraction and also make interacting with outputs `Vec` awkward. +/// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be +/// inconvenient. #[derive(Debug)] -pub(crate) struct OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) tensor_ptr: *mut sys::OrtValue, - memory_info: &'m MemoryInfo, - shape: D, +pub(crate) struct TensorPointerDropper { + tensor_ptr: *mut sys::OrtValue, } -impl<'m, D> OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> { - OrtOwnedTensorExtractor { - tensor_ptr: std::ptr::null_mut(), - memory_info, - shape, - } - } - - pub(crate) fn extract<'t, T>(self) -> Result<OrtOwnedTensor<'t, 'm, T, D>> - where - T: TypeToTensorElementDataType + Debug + Clone, - { - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need too free the pointer used to create the ArrayView. 
- - assert_ne!(self.tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_eq!(is_tensor, 1); - - // Get pointer to output tensor float values - let mut output_array_ptr: *mut T = std::ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = - output_array_ptr_ptr as *mut *mut std::ffi::c_void; - let status = unsafe { - g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void) - }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_ne!(output_array_ptr, std::ptr::null_mut()); - - let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) }; - - Ok(OrtOwnedTensor { - tensor_ptr: self.tensor_ptr, - array_view, - memory_info: self.memory_info, - }) - } -} - -impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, - 'm: 't, // 'm outlives 't -{ +impl Drop for TensorPointerDropper { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping OrtOwnedTensor."); unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) } - self.tensor_ptr = std::ptr::null_mut(); + self.tensor_ptr = ptr::null_mut(); } } diff --git a/onnxruntime/src/tensor/ort_tensor.rs b/onnxruntime/src/tensor/ort_tensor.rs index 437e2e86..0937afe1 100644 --- a/onnxruntime/src/tensor/ort_tensor.rs +++ b/onnxruntime/src/tensor/ort_tensor.rs @@ -8,9 +8,11 @@ use tracing::{debug, error}; use onnxruntime_sys as sys; use crate::{ - error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo, - tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType, - TypeToTensorElementDataType, + error::{call_ort, status_to_result}, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, 
TensorElementDataType, TypeToTensorElementDataType}, + OrtError, Result, }; /// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index ee531feb..2a2ea164 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -15,6 +15,7 @@ mod download { use onnxruntime::{ download::vision::{DomainBasedImageClassification, ImageClassification}, environment::Environment, + tensor::{DynOrtTensor, OrtOwnedTensor}, GraphOptimizationLevel, LoggingLevel, }; @@ -93,13 +94,13 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = + session.run(input_tensor_values).unwrap(); // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. 
- let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -184,11 +185,11 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = + session.run(input_tensor_values).unwrap(); - let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -282,12 +283,12 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = + session.run(input_tensor_values).unwrap(); assert_eq!(outputs.len(), 1); - let output = &outputs[0]; + let output: OrtOwnedTensor<'_, f32, ndarray::Dim<ndarray::IxDynImpl>> = + outputs[0].try_extract().unwrap(); // The image should have doubled in size assert_eq!(output.shape(), [1, 448, 448, 3]); diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs new file mode 100644 index 00000000..fe4c0da9 --- /dev/null +++ b/onnxruntime/tests/string_type.rs @@ -0,0 +1,48 @@ +use std::error::Error; + +use ndarray; +use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType}; +use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel}; + +#[test] +fn run_model_with_string_input_output() -> Result<(), Box<dyn Error>> { + let environment = Environment::builder() + .with_name("test") + 
.with_log_level(LoggingLevel::Verbose) + .build()?; + + let mut session = environment + .new_session_builder()? + .with_model_from_file("../test-models/tensorflow/unique_model.onnx")?; + + // Inputs: + // 0: + // name = input_1:0 + // type = String + // dimensions = [None] + // Outputs: + // 0: + // name = Identity:0 + // type = Int32 + // dimensions = [None] + // 1: + // name = Identity_1:0 + // type = String + // dimensions = [None] + + let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo"]); + let input_tensor_values = vec![array]; + + let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?; + + assert_eq!(TensorElementDataType::Int32, outputs[0].data_type()); + assert_eq!(TensorElementDataType::String, outputs[1].data_type()); + + let int_output: OrtOwnedTensor<i32, _> = outputs[0].try_extract()?; + + assert_eq!(&[0, 1, 0, 0], int_output.as_slice().unwrap()); + + // TODO get the string output once string extraction is implemented + + Ok(()) +} diff --git a/test-models/tensorflow/.gitignore b/test-models/tensorflow/.gitignore new file mode 100644 index 00000000..aea6a084 --- /dev/null +++ b/test-models/tensorflow/.gitignore @@ -0,0 +1,2 @@ +/Pipfile.lock +/models diff --git a/test-models/tensorflow/Pipfile b/test-models/tensorflow/Pipfile new file mode 100644 index 00000000..a7b370ab --- /dev/null +++ b/test-models/tensorflow/Pipfile @@ -0,0 +1,13 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +tensorflow = "==2.4.1" +tf2onnx = "==1.8.3" + +[dev-packages] + +[requires] +python_version = "3.8" diff --git a/test-models/tensorflow/README.md b/test-models/tensorflow/README.md new file mode 100644 index 00000000..4f2e68f2 --- /dev/null +++ b/test-models/tensorflow/README.md @@ -0,0 +1,18 @@ +# Setup + +Have Pipenv make the virtualenv for you: + +``` +pipenv install +``` + +# Model: Unique + +A TensorFlow model that removes duplicate tensor elements. 
+ +This supports strings, and doesn't require custom operators. + +``` +pipenv run python src/unique_model.py +pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11 +``` diff --git a/test-models/tensorflow/src/unique_model.py b/test-models/tensorflow/src/unique_model.py new file mode 100644 index 00000000..fb79dc8b --- /dev/null +++ b/test-models/tensorflow/src/unique_model.py @@ -0,0 +1,19 @@ +import tensorflow as tf +import numpy as np +import tf2onnx + + +class UniqueModel(tf.keras.Model): + + def __init__(self, name='model1', **kwargs): + super(UniqueModel, self).__init__(name=name, **kwargs) + + def call(self, inputs): + return tf.unique(inputs) + + +model1 = UniqueModel() + +print(model1(tf.constant(["foo", "bar", "foo", "baz"]))) + +model1.save("models/unique_model") diff --git a/test-models/tensorflow/unique_model.onnx b/test-models/tensorflow/unique_model.onnx new file mode 100644 index 00000000..320f6200 Binary files /dev/null and b/test-models/tensorflow/unique_model.onnx differ