
Commit 0d34d23

Introduce a new trait to represent types that can be used as output from a tensor
This is prep work for string output types and for tensor types that vary across a model's outputs. For now, the supported output types are just the basic numeric types. Since strings have to be copied out of a tensor, it only makes sense for `String` to be an output type, not `str`; hence the new trait, which lets us support more input types than output types.
1 parent 3b5fcd3 commit 0d34d23
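The new trait itself is defined in the tensor module, which is among the five changed files but not shown below. Judging purely from how it is used in session.rs (`TOut: TensorDataToType` and `TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)`), a rough sketch of its shape might look like the following; the exact method set, bounds, and lifetimes here are assumptions, not the committed definition.

// Hypothetical sketch only; assumes the crate's `Result` alias and the
// `TensorElementDataType` enum that this commit moves into the tensor module.
use ndarray::ArrayView;
use onnxruntime_sys as sys;

pub trait TensorDataToType: Sized {
    /// The ONNX element type corresponding to `Self`.
    fn tensor_element_data_type() -> TensorElementDataType;

    /// Borrow the tensor's data as an `ArrayView` with the given shape.
    fn extract_array<'t, D>(shape: D, tensor: *mut sys::OrtValue) -> Result<ArrayView<'t, Self, D>>
    where
        D: ndarray::Dimension;
}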

File tree

5 files changed (+322, −255 lines)

onnxruntime/src/lib.rs

-148
@@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
     }
 }
 
-// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
-// FIXME: Add tests to cover the commented out types
-/// Enum mapping ONNX Runtime's supported tensor types
-#[derive(Debug)]
-#[cfg_attr(not(windows), repr(u32))]
-#[cfg_attr(windows, repr(i32))]
-pub enum TensorElementDataType {
-    /// 32-bit floating point, equivalent to Rust's `f32`
-    Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
-    /// Unsigned 8-bit int, equivalent to Rust's `u8`
-    Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
-    /// Signed 8-bit int, equivalent to Rust's `i8`
-    Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
-    /// Unsigned 16-bit int, equivalent to Rust's `u16`
-    Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
-    /// Signed 16-bit int, equivalent to Rust's `i16`
-    Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
-    /// Signed 32-bit int, equivalent to Rust's `i32`
-    Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
-    /// Signed 64-bit int, equivalent to Rust's `i64`
-    Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
-    /// String, equivalent to Rust's `String`
-    String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
-    // /// Boolean, equivalent to Rust's `bool`
-    // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
-    // /// 16-bit floating point, equivalent to Rust's `f16`
-    // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
-    /// 64-bit floating point, equivalent to Rust's `f64`
-    Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
-    /// Unsigned 32-bit int, equivalent to Rust's `u32`
-    Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
-    /// Unsigned 64-bit int, equivalent to Rust's `u64`
-    Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
-    // /// Complex 64-bit floating point, equivalent to Rust's `???`
-    // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
-    // /// Complex 128-bit floating point, equivalent to Rust's `???`
-    // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
-    // /// Brain 16-bit floating point
-    // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
-}
-
-impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
-    fn into(self) -> sys::ONNXTensorElementDataType {
-        use TensorElementDataType::*;
-        match self {
-            Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
-            Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
-            Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
-            Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
-            Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
-            Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
-            Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
-            String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
-            // Bool => {
-            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
-            // }
-            // Float16 => {
-            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
-            // }
-            Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
-            Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
-            Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
-            // Complex64 => {
-            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
-            // }
-            // Complex128 => {
-            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
-            // }
-            // Bfloat16 => {
-            //     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
-            // }
-        }
-    }
-}
-
-/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
-pub trait TypeToTensorElementDataType {
-    /// Return the ONNX type for a Rust type
-    fn tensor_element_data_type() -> TensorElementDataType;
-
-    /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
-    fn try_utf8_bytes(&self) -> Option<&[u8]>;
-}
-
-macro_rules! impl_type_trait {
-    ($type_:ty, $variant:ident) => {
-        impl TypeToTensorElementDataType for $type_ {
-            fn tensor_element_data_type() -> TensorElementDataType {
-                // unsafe { std::mem::transmute(TensorElementDataType::$variant) }
-                TensorElementDataType::$variant
-            }
-
-            fn try_utf8_bytes(&self) -> Option<&[u8]> {
-                None
-            }
-        }
-    };
-}
-
-impl_type_trait!(f32, Float);
-impl_type_trait!(u8, Uint8);
-impl_type_trait!(i8, Int8);
-impl_type_trait!(u16, Uint16);
-impl_type_trait!(i16, Int16);
-impl_type_trait!(i32, Int32);
-impl_type_trait!(i64, Int64);
-// impl_type_trait!(bool, Bool);
-// impl_type_trait!(f16, Float16);
-impl_type_trait!(f64, Double);
-impl_type_trait!(u32, Uint32);
-impl_type_trait!(u64, Uint64);
-// impl_type_trait!(, Complex64);
-// impl_type_trait!(, Complex128);
-// impl_type_trait!(, Bfloat16);
-
-/// Adapter for common Rust string types to Onnx strings.
-///
-/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
-/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
-/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
-/// types (which might implement `AsRef<str>` at some point in the future).
-pub trait Utf8Data {
-    /// Returns the utf8 contents.
-    fn utf8_bytes(&self) -> &[u8];
-}
-
-impl Utf8Data for String {
-    fn utf8_bytes(&self) -> &[u8] {
-        self.as_bytes()
-    }
-}
-
-impl<'a> Utf8Data for &'a str {
-    fn utf8_bytes(&self) -> &[u8] {
-        self.as_bytes()
-    }
-}
-
-impl<T: Utf8Data> TypeToTensorElementDataType for T {
-    fn tensor_element_data_type() -> TensorElementDataType {
-        TensorElementDataType::String
-    }
-
-    fn try_utf8_bytes(&self) -> Option<&[u8]> {
-        Some(self.utf8_bytes())
-    }
-}
-
 /// Allocator type
 #[derive(Debug, Clone)]
 #[repr(i32)]
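None of the items above are lost: they move into the tensor module, which session.rs (next file) now imports them from. For reference, the input-side mapping they provide works roughly as in the sketch below; whether the names remain reachable from the crate root after the move is an assumption here.

use onnxruntime::{TensorElementDataType, TypeToTensorElementDataType};

fn main() {
    // Primitive numeric types report their ONNX element type and carry no UTF-8 payload.
    assert!(matches!(
        f32::tensor_element_data_type(),
        TensorElementDataType::Float
    ));
    assert_eq!(1.0f32.try_utf8_bytes(), None);

    // Both `String` and `&str` go through the blanket impl over `Utf8Data`.
    assert!(matches!(
        String::tensor_element_data_type(),
        TensorElementDataType::String
    ));
    assert_eq!("hi".try_utf8_bytes(), Some("hi".as_bytes()));
}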

onnxruntime/src/session.rs

+73 −39
@@ -18,15 +18,14 @@ use onnxruntime_sys as sys;
 use crate::{
     char_p_to_string,
     environment::Environment,
-    error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
+    error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result},
     g_ort,
     memory::MemoryInfo,
     tensor::{
-        ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
-        OrtTensor,
+        ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType,
+        TypeToTensorElementDataType,
     },
-    AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
-    TypeToTensorElementDataType,
+    AllocatorType, GraphOptimizationLevel, MemType,
 };
 
 #[cfg(feature = "model-fetching")]
@@ -371,7 +370,7 @@ impl<'a> Session<'a> {
     ) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
     where
         TIn: TypeToTensorElementDataType + Debug + Clone,
-        TOut: TypeToTensorElementDataType + Debug + Clone,
+        TOut: TensorDataToType,
         D: ndarray::Dimension,
         'm: 't, // 'm outlives 't (memory info outlives tensor)
         's: 'm, // 's outlives 'm (session outlives memory info)
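With this bound change, the input element type and the output element type are constrained by separate traits. A hypothetical call-site fragment is sketched below; `Session::run`, the `vec![input]` argument form, and the elided lifetimes on `OrtOwnedTensor` are assumptions drawn from the crate's existing API, not from this hunk.

// Hypothetical call-site fragment; `session` is an already-built Session.
let input = ndarray::Array::<f32, _>::zeros((1, 3, 224, 224));
// TIn still needs TypeToTensorElementDataType, but TOut now only needs
// TensorDataToType, so output types no longer have to satisfy the input-side trait.
let outputs: Vec<OrtOwnedTensor<f32, ndarray::IxDyn>> = session.run(vec![input])?;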
@@ -440,21 +439,30 @@ impl<'a> Session<'a> {
         let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
             output_tensor_extractors_ptrs
                 .into_iter()
-                .map(|ptr| {
-                    let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
-                        std::ptr::null_mut();
-                    let status = unsafe {
-                        g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
-                    };
-                    status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
-                    let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
-                    unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
-                    let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();
-
-                    let mut output_tensor_extractor =
-                        OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
-                    output_tensor_extractor.tensor_ptr = ptr;
-                    output_tensor_extractor.extract::<TOut>()
+                .map(|tensor_ptr| {
+                    let dims = unsafe {
+                        call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
+                            get_tensor_dimensions(tensor_info_ptr)
+                                .map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
+                        })
+                    }?;
+
+                    // Note: Both tensor and array will point to the same data, nothing is copied.
+                    // As such, there is no need to free the pointer used to create the ArrayView.
+                    assert_ne!(tensor_ptr, std::ptr::null_mut());
+
+                    let mut is_tensor = 0;
+                    unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) }
+                        .map_err(OrtError::IsTensor)?;
+                    assert_eq!(is_tensor, 1);
+
+                    let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?;
+
+                    Ok(OrtOwnedTensor::new(
+                        tensor_ptr,
+                        array_view,
+                        &memory_info_ref,
+                    ))
                 })
                 .collect();
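The heavy lifting in the closure above is delegated to `TOut::extract_array`, whose implementations live in the tensor module and are not part of this diff. As a rough idea of what an implementation for a primitive element type might do, here is a sketch; the `GetTensorMutableData` call and the `OrtError::GetTensorMutableData` variant are guesses, not the committed code.

// Hypothetical sketch of an implementation for f32.
impl TensorDataToType for f32 {
    fn tensor_element_data_type() -> TensorElementDataType {
        TensorElementDataType::Float
    }

    fn extract_array<'t, D>(
        shape: D,
        tensor: *mut sys::OrtValue,
    ) -> Result<ndarray::ArrayView<'t, Self, D>>
    where
        D: ndarray::Dimension,
    {
        let mut data_ptr: *mut std::ffi::c_void = std::ptr::null_mut();
        // GetTensorMutableData hands back a pointer into memory owned by the OrtValue.
        unsafe { call_ort(|ort| ort.GetTensorMutableData.unwrap()(tensor, &mut data_ptr)) }
            .map_err(OrtError::GetTensorMutableData)?; // error variant name assumed
        assert_ne!(data_ptr, std::ptr::null_mut());
        // Safety: nothing is copied; the view must not outlive the OrtValue it points into.
        Ok(unsafe { ndarray::ArrayView::from_shape_ptr(shape, data_ptr as *const f32) })
    }
}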

@@ -554,25 +562,60 @@ unsafe fn get_tensor_dimensions(
     tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
 ) -> Result<Vec<i64>> {
     let mut num_dims = 0;
-    let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
-    status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
+    call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims))
+        .map_err(OrtError::GetDimensionsCount)?;
     assert_ne!(num_dims, 0);
 
     let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
-    let status = g_ort().GetDimensions.unwrap()(
-        tensor_info_ptr,
-        node_dims.as_mut_ptr(), // FIXME: UB?
-        num_dims,
-    );
-    status_to_result(status).map_err(OrtError::GetDimensions)?;
+    call_ort(|ort| {
+        ort.GetDimensions.unwrap()(
+            tensor_info_ptr,
+            node_dims.as_mut_ptr(), // FIXME: UB?
+            num_dims,
+        )
+    })
+    .map_err(OrtError::GetDimensions)?;
     Ok(node_dims)
 }
 
+unsafe fn extract_data_type(
+    tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo,
+) -> Result<TensorElementDataType> {
+    let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+    call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys))
+        .map_err(OrtError::TensorElementType)?;
+    assert_ne!(
+        type_sys,
+        sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
+    );
+    // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
+    Ok(std::mem::transmute(type_sys))
+}
+
+/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the
+/// resulting `*OrtTensorTypeAndShapeInfo` before returning.
+unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result<T>
+where
+    F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result<T>,
+{
+    let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
+    call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _))
+        .map_err(OrtError::GetTensorTypeAndShape)?;
+
+    let res = f(tensor_info_ptr);
+
+    // no return code, so no errors to check for
+    g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr);
+
+    res
+}
+
 /// This module contains dangerous functions working on raw pointers.
 /// Those functions are only to be used from inside the
 /// `SessionBuilder::with_model_from_file()` method.
 mod dangerous {
     use super::*;
+    use crate::tensor::TensorElementDataType;
 
     pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<u64> {
         let f = g_ort().SessionGetInputCount.unwrap();
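Because `call_with_tensor_info` takes a closure and releases the `OrtTensorTypeAndShapeInfo` itself, it composes with the other helpers. A small hypothetical usage (not part of this commit), assuming `tensor_ptr` is a valid `*mut sys::OrtValue` obtained from a run, could read the shape and element type in one pass:

// Hypothetical usage of the helpers above.
let (dims, element_type) = unsafe {
    call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
        Ok((
            get_tensor_dimensions(tensor_info_ptr)?,
            extract_data_type(tensor_info_ptr)?,
        ))
    })
}?;
println!("output element type {:?}, shape {:?}", element_type, dims);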
@@ -689,16 +732,7 @@ mod dangerous {
         status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
         assert_ne!(tensor_info_ptr, std::ptr::null_mut());
 
-        let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-        let status =
-            unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
-        status_to_result(status).map_err(OrtError::TensorElementType)?;
-        assert_ne!(
-            type_sys,
-            sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
-        );
-        // This transmute should be safe since its value is read from GetTensorElementType which we must trust.
-        let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };
+        let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };
 
         // info!("{} : type={}", i, type_);
