Skip to content

Commit 88c6bab

Browse files
committed
Merge branch 'remove-panics-from-session'
2 parents a9358fd + b3e1a26 commit 88c6bab

File tree

7 files changed

+102
-46
lines changed

7 files changed

+102
-46
lines changed

onnxruntime/src/environment.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
};
77

88
use lazy_static::lazy_static;
9-
use tracing::{debug, warn};
9+
use tracing::{debug, error, warn};
1010

1111
use onnxruntime_sys as sys;
1212

@@ -182,7 +182,11 @@ impl Drop for Environment {
182182
);
183183

184184
assert_ne!(env_ptr, std::ptr::null_mut());
185-
unsafe { release_env(env_ptr) };
185+
if env_ptr.is_null() {
186+
error!("Environment pointer is null, not dropping!");
187+
} else {
188+
unsafe { release_env(env_ptr) };
189+
}
186190

187191
environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
188192
environment_guard.name = String::from("uninitialized");

onnxruntime/src/error.rs

+27
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,21 @@ pub enum OrtError {
100100
/// Attempt to build a Rust `CString` from a null pointer
101101
#[error("Failed to build CString when original contains null: {0}")]
102102
CStringNulError(#[from] std::ffi::NulError),
103+
#[error("{0} pointer should be null")]
104+
/// Ort Pointer should have been null
105+
PointerShouldBeNull(String),
106+
/// Ort pointer should not have been null
107+
#[error("{0} pointer should not be null")]
108+
PointerShouldNotBeNull(String),
109+
/// ONNX Model has invalid dimensions
110+
#[error("Invalid dimensions")]
111+
InvalidDimensions,
112+
/// The runtime type was undefined
113+
#[error("Undefined Tensor Element Type")]
114+
UndefinedTensorElementType,
115+
/// Error occurred when checking if ONNX tensor was properly initialized
116+
#[error("Failed to check if tensor")]
117+
IsTensorCheck,
103118
}
104119

105120
/// Error used when dimensions of input (from model and from inference call)
@@ -176,6 +191,18 @@ impl From<*const sys::OrtStatus> for OrtStatusWrapper {
176191
}
177192
}
178193

194+
pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
195+
ptr.is_null()
196+
.then(|| ())
197+
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
198+
}
199+
200+
pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
201+
(!ptr.is_null())
202+
.then(|| ())
203+
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
204+
}
205+
179206
impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
180207
fn from(status: OrtStatusWrapper) -> Self {
181208
if status.0.is_null() {

onnxruntime/src/memory.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ use tracing::debug;
22

33
use onnxruntime_sys as sys;
44

5+
use tracing::error;
6+
57
use crate::{
6-
error::{status_to_result, OrtError, Result},
8+
error::{assert_not_null_pointer, status_to_result, OrtError, Result},
79
g_ort, AllocatorType, MemType,
810
};
911

@@ -25,7 +27,7 @@ impl MemoryInfo {
2527
)
2628
};
2729
status_to_result(status).map_err(OrtError::CreateCpuMemoryInfo)?;
28-
assert_ne!(memory_info_ptr, std::ptr::null_mut());
30+
assert_not_null_pointer(memory_info_ptr, "MemoryInfo")?;
2931

3032
Ok(Self {
3133
ptr: memory_info_ptr,
@@ -36,10 +38,12 @@ impl MemoryInfo {
3638
impl Drop for MemoryInfo {
3739
#[tracing::instrument]
3840
fn drop(&mut self) {
39-
debug!("Dropping the memory information.");
40-
assert_ne!(self.ptr, std::ptr::null_mut());
41-
42-
unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
41+
if self.ptr.is_null() {
42+
error!("MemoryInfo pointer is null, not dropping.");
43+
} else {
44+
debug!("Dropping the memory information.");
45+
unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
46+
}
4347

4448
self.ptr = std::ptr::null_mut();
4549
}

onnxruntime/src/session.rs

+43-29
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
1818
use crate::{
1919
char_p_to_string,
2020
environment::Environment,
21-
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
21+
error::{
22+
assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError,
23+
OrtApiError, OrtError, Result,
24+
},
2225
g_ort,
2326
memory::MemoryInfo,
2427
tensor::{
@@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
7376
impl<'a> Drop for SessionBuilder<'a> {
7477
#[tracing::instrument]
7578
fn drop(&mut self) {
76-
debug!("Dropping the session options.");
77-
assert_ne!(self.session_options_ptr, std::ptr::null_mut());
78-
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
79+
if self.session_options_ptr.is_null() {
80+
error!("Session options pointer is null, not dropping");
81+
} else {
82+
debug!("Dropping the session options.");
83+
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
84+
}
7985
}
8086
}
8187

@@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
8591
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };
8692

8793
status_to_result(status).map_err(OrtError::SessionOptions)?;
88-
assert_eq!(status, std::ptr::null_mut());
89-
assert_ne!(session_options_ptr, std::ptr::null_mut());
94+
assert_null_pointer(status, "SessionStatus")?;
95+
assert_not_null_pointer(session_options_ptr, "SessionOptions")?;
9096

9197
Ok(SessionBuilder {
9298
env,
@@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
105111
let status =
106112
unsafe { g_ort().SetIntraOpNumThreads.unwrap()(self.session_options_ptr, num_threads) };
107113
status_to_result(status).map_err(OrtError::SessionOptions)?;
108-
assert_eq!(status, std::ptr::null_mut());
114+
assert_null_pointer(status, "SessionStatus")?;
109115
Ok(self)
110116
}
111117

@@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
199205
)
200206
};
201207
status_to_result(status).map_err(OrtError::Session)?;
202-
assert_eq!(status, std::ptr::null_mut());
203-
assert_ne!(session_ptr, std::ptr::null_mut());
208+
assert_null_pointer(status, "SessionStatus")?;
209+
assert_not_null_pointer(session_ptr, "Session")?;
204210

205211
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
206212
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
207213
status_to_result(status).map_err(OrtError::Allocator)?;
208-
assert_eq!(status, std::ptr::null_mut());
209-
assert_ne!(allocator_ptr, std::ptr::null_mut());
214+
assert_null_pointer(status, "SessionStatus")?;
215+
assert_not_null_pointer(allocator_ptr, "Allocator")?;
210216

211217
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
212218

@@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
255261
)
256262
};
257263
status_to_result(status).map_err(OrtError::Session)?;
258-
assert_eq!(status, std::ptr::null_mut());
259-
assert_ne!(session_ptr, std::ptr::null_mut());
264+
assert_null_pointer(status, "SessionStatus")?;
265+
assert_not_null_pointer(session_ptr, "Session")?;
260266

261267
let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
262268
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
263269
status_to_result(status).map_err(OrtError::Allocator)?;
264-
assert_eq!(status, std::ptr::null_mut());
265-
assert_ne!(allocator_ptr, std::ptr::null_mut());
270+
assert_null_pointer(status, "SessionStatus")?;
271+
assert_not_null_pointer(allocator_ptr, "Allocator")?;
266272

267273
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
268274

@@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
352358
#[tracing::instrument]
353359
fn drop(&mut self) {
354360
debug!("Dropping the session.");
355-
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
361+
if self.session_ptr.is_null() {
362+
error!("Session pointer is null, not dropping.");
363+
} else {
364+
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
365+
}
356366
// FIXME: There is no C function to release the allocator?
357367

358368
self.session_ptr = std::ptr::null_mut();
@@ -453,13 +463,14 @@ impl<'a> Session<'a> {
453463
.collect();
454464

455465
// Reconvert to CString so drop impl is called and memory is freed
456-
let _: Vec<CString> = input_names_ptr
466+
let cstrings: Result<Vec<CString>> = input_names_ptr
457467
.into_iter()
458468
.map(|p| {
459-
assert_ne!(p, std::ptr::null());
460-
unsafe { CString::from_raw(p as *mut i8) }
469+
assert_not_null_pointer(p, "i8 for CString")?;
470+
unsafe { Ok(CString::from_raw(p as *mut i8)) }
461471
})
462472
.collect();
473+
cstrings?;
463474

464475
outputs
465476
}
@@ -568,7 +579,9 @@ unsafe fn get_tensor_dimensions(
568579
let mut num_dims = 0;
569580
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
570581
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
571-
assert_ne!(num_dims, 0);
582+
(num_dims != 0)
583+
.then(|| ())
584+
.ok_or(OrtError::InvalidDimensions)?;
572585

573586
let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
574587
let status = g_ort().GetDimensions.unwrap()(
@@ -603,8 +616,10 @@ mod dangerous {
603616
let mut num_nodes: usize = 0;
604617
let status = unsafe { f(session_ptr, &mut num_nodes) };
605618
status_to_result(status).map_err(OrtError::InOutCount)?;
606-
assert_eq!(status, std::ptr::null_mut());
607-
assert_ne!(num_nodes, 0);
619+
assert_null_pointer(status, "SessionStatus")?;
620+
(num_nodes != 0).then(|| ()).ok_or_else(|| {
621+
OrtError::InOutCount(OrtApiError::Msg("No nodes in model".to_owned()))
622+
})?;
608623
Ok(num_nodes)
609624
}
610625

@@ -641,7 +656,7 @@ mod dangerous {
641656

642657
let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
643658
status_to_result(status).map_err(OrtError::InputName)?;
644-
assert_ne!(name_bytes, std::ptr::null_mut());
659+
assert_not_null_pointer(name_bytes, "InputName")?;
645660

646661
// FIXME: Is it safe to keep ownership of the memory?
647662
let name = char_p_to_string(name_bytes)?;
@@ -692,23 +707,22 @@ mod dangerous {
692707

693708
let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
694709
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
695-
assert_ne!(typeinfo_ptr, std::ptr::null_mut());
710+
assert_not_null_pointer(typeinfo_ptr, "TypeInfo")?;
696711

697712
let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
698713
let status = unsafe {
699714
g_ort().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr)
700715
};
701716
status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
702-
assert_ne!(tensor_info_ptr, std::ptr::null_mut());
717+
assert_not_null_pointer(tensor_info_ptr, "TensorInfo")?;
703718

704719
let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
705720
let status =
706721
unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
707722
status_to_result(status).map_err(OrtError::TensorElementType)?;
708-
assert_ne!(
709-
type_sys,
710-
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
711-
);
723+
(type_sys != sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
724+
.then(|| ())
725+
.ok_or(OrtError::UndefinedTensorElementType)?;
712726
// This transmute should be safe since its value is read from GetTensorElementType which we must trust.
713727
let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };
714728

onnxruntime/src/tensor/ort_owned_tensor.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,9 @@ where
9595
let mut is_tensor = 0;
9696
let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
9797
status_to_result(status).map_err(OrtError::IsTensor)?;
98-
assert_eq!(is_tensor, 1);
98+
(is_tensor == 1)
99+
.then(|| ())
100+
.ok_or(OrtError::IsTensorCheck)?;
99101

100102
// Get pointer to output tensor float values
101103
let mut output_array_ptr: *mut T = std::ptr::null_mut();

onnxruntime/src/tensor/ort_tensor.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ use tracing::{debug, error};
88
use onnxruntime_sys as sys;
99

1010
use crate::{
11-
error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo,
12-
tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType,
13-
TypeToTensorElementDataType,
11+
error::{assert_not_null_pointer, call_ort, status_to_result},
12+
g_ort,
13+
memory::MemoryInfo,
14+
tensor::ndarray_tensor::NdArrayTensor,
15+
OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
1416
};
1517

1618
/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
@@ -67,7 +69,7 @@ where
6769
// onnxruntime as is
6870
let tensor_values_ptr: *mut std::ffi::c_void =
6971
array.as_mut_ptr() as *mut std::ffi::c_void;
70-
assert_ne!(tensor_values_ptr, std::ptr::null_mut());
72+
assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;
7173

7274
unsafe {
7375
call_ort(|ort| {
@@ -83,7 +85,7 @@ where
8385
})
8486
}
8587
.map_err(OrtError::CreateTensorWithData)?;
86-
assert_ne!(tensor_ptr, std::ptr::null_mut());
88+
assert_not_null_pointer(tensor_ptr, "Tensor")?;
8789

8890
let mut is_tensor = 0;
8991
let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
@@ -134,7 +136,7 @@ where
134136
}
135137
}
136138

137-
assert_ne!(tensor_ptr, std::ptr::null_mut());
139+
assert_not_null_pointer(tensor_ptr, "Tensor")?;
138140

139141
Ok(OrtTensor {
140142
c_ptr: tensor_ptr,

onnxruntime/tests/integration_tests.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66
};
77

88
use onnxruntime::error::OrtDownloadError;
9+
use onnxruntime::tensor::OrtOwnedTensor;
910

1011
mod download {
1112
use super::*;
@@ -101,7 +102,8 @@ mod download {
101102

102103
// Downloaded model does not have a softmax as final layer; call softmax on second axis
103104
// and iterate on resulting probabilities, creating an index to later access labels.
104-
let mut probabilities: Vec<(usize, f32)> = outputs[0]
105+
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
106+
let mut probabilities: Vec<(usize, f32)> = output
105107
.softmax(ndarray::Axis(1))
106108
.iter()
107109
.copied()
@@ -190,7 +192,8 @@ mod download {
190192
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
191193
> = session.run(input_tensor_values).unwrap();
192194

193-
let mut probabilities: Vec<(usize, f32)> = outputs[0]
195+
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
196+
let mut probabilities: Vec<(usize, f32)> = output
194197
.softmax(ndarray::Axis(1))
195198
.iter()
196199
.copied()

0 commit comments

Comments
 (0)