@@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
18
18
use crate :: {
19
19
char_p_to_string,
20
20
environment:: Environment ,
21
- error:: { status_to_result, NonMatchingDimensionsError , OrtError , Result } ,
21
+ error:: {
22
+ assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError ,
23
+ OrtApiError , OrtError , Result ,
24
+ } ,
22
25
g_ort,
23
26
memory:: MemoryInfo ,
24
27
tensor:: {
@@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
73
76
impl < ' a > Drop for SessionBuilder < ' a > {
74
77
#[ tracing:: instrument]
75
78
fn drop ( & mut self ) {
76
- debug ! ( "Dropping the session options." ) ;
77
- assert_ne ! ( self . session_options_ptr, std:: ptr:: null_mut( ) ) ;
78
- unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
79
+ if self . session_options_ptr . is_null ( ) {
80
+ error ! ( "Session options pointer is null, not dropping" ) ;
81
+ } else {
82
+ debug ! ( "Dropping the session options." ) ;
83
+ unsafe { g_ort ( ) . ReleaseSessionOptions . unwrap ( ) ( self . session_options_ptr ) } ;
84
+ }
79
85
}
80
86
}
81
87
@@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
85
91
let status = unsafe { g_ort ( ) . CreateSessionOptions . unwrap ( ) ( & mut session_options_ptr) } ;
86
92
87
93
status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
88
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
89
- assert_ne ! ( session_options_ptr, std :: ptr :: null_mut ( ) ) ;
94
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
95
+ assert_not_null_pointer ( session_options_ptr, "SessionOptions" ) ? ;
90
96
91
97
Ok ( SessionBuilder {
92
98
env,
@@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
105
111
let status =
106
112
unsafe { g_ort ( ) . SetIntraOpNumThreads . unwrap ( ) ( self . session_options_ptr , num_threads) } ;
107
113
status_to_result ( status) . map_err ( OrtError :: SessionOptions ) ?;
108
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
114
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
109
115
Ok ( self )
110
116
}
111
117
@@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
199
205
)
200
206
} ;
201
207
status_to_result ( status) . map_err ( OrtError :: Session ) ?;
202
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
203
- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
208
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
209
+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
204
210
205
211
let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
206
212
let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
207
213
status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
208
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
209
- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
214
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
215
+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
210
216
211
217
let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
212
218
@@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
255
261
)
256
262
} ;
257
263
status_to_result ( status) . map_err ( OrtError :: Session ) ?;
258
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
259
- assert_ne ! ( session_ptr, std :: ptr :: null_mut ( ) ) ;
264
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
265
+ assert_not_null_pointer ( session_ptr, "Session" ) ? ;
260
266
261
267
let mut allocator_ptr: * mut sys:: OrtAllocator = std:: ptr:: null_mut ( ) ;
262
268
let status = unsafe { g_ort ( ) . GetAllocatorWithDefaultOptions . unwrap ( ) ( & mut allocator_ptr) } ;
263
269
status_to_result ( status) . map_err ( OrtError :: Allocator ) ?;
264
- assert_eq ! ( status, std :: ptr :: null_mut ( ) ) ;
265
- assert_ne ! ( allocator_ptr, std :: ptr :: null_mut ( ) ) ;
270
+ assert_null_pointer ( status, "SessionStatus" ) ? ;
271
+ assert_not_null_pointer ( allocator_ptr, "Allocator" ) ? ;
266
272
267
273
let memory_info = MemoryInfo :: new ( AllocatorType :: Arena , MemType :: Default ) ?;
268
274
@@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
352
358
#[ tracing:: instrument]
353
359
fn drop ( & mut self ) {
354
360
debug ! ( "Dropping the session." ) ;
355
- unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
361
+ if self . session_ptr . is_null ( ) {
362
+ error ! ( "Session pointer is null, not dropping." ) ;
363
+ } else {
364
+ unsafe { g_ort ( ) . ReleaseSession . unwrap ( ) ( self . session_ptr ) } ;
365
+ }
356
366
// FIXME: There is no C function to release the allocator?
357
367
358
368
self . session_ptr = std:: ptr:: null_mut ( ) ;
@@ -453,13 +463,14 @@ impl<'a> Session<'a> {
453
463
. collect ( ) ;
454
464
455
465
// Reconvert to CString so drop impl is called and memory is freed
456
- let _ : Vec < CString > = input_names_ptr
466
+ let cstrings : Result < Vec < CString > > = input_names_ptr
457
467
. into_iter ( )
458
468
. map ( |p| {
459
- assert_ne ! ( p, std :: ptr :: null ( ) ) ;
460
- unsafe { CString :: from_raw ( p as * mut i8 ) }
469
+ assert_not_null_pointer ( p, "i8 for CString" ) ? ;
470
+ unsafe { Ok ( CString :: from_raw ( p as * mut i8 ) ) }
461
471
} )
462
472
. collect ( ) ;
473
+ cstrings?;
463
474
464
475
outputs
465
476
}
@@ -568,7 +579,9 @@ unsafe fn get_tensor_dimensions(
568
579
let mut num_dims = 0 ;
569
580
let status = g_ort ( ) . GetDimensionsCount . unwrap ( ) ( tensor_info_ptr, & mut num_dims) ;
570
581
status_to_result ( status) . map_err ( OrtError :: GetDimensionsCount ) ?;
571
- assert_ne ! ( num_dims, 0 ) ;
582
+ ( num_dims != 0 )
583
+ . then ( || ( ) )
584
+ . ok_or ( OrtError :: InvalidDimensions ) ?;
572
585
573
586
let mut node_dims: Vec < i64 > = vec ! [ 0 ; num_dims as usize ] ;
574
587
let status = g_ort ( ) . GetDimensions . unwrap ( ) (
@@ -603,8 +616,10 @@ mod dangerous {
603
616
let mut num_nodes: usize = 0 ;
604
617
let status = unsafe { f ( session_ptr, & mut num_nodes) } ;
605
618
status_to_result ( status) . map_err ( OrtError :: InOutCount ) ?;
606
- assert_eq ! ( status, std:: ptr:: null_mut( ) ) ;
607
- assert_ne ! ( num_nodes, 0 ) ;
619
+ assert_null_pointer ( status, "SessionStatus" ) ?;
620
+ ( num_nodes != 0 ) . then ( || ( ) ) . ok_or_else ( || {
621
+ OrtError :: InOutCount ( OrtApiError :: Msg ( "No nodes in model" . to_owned ( ) ) )
622
+ } ) ?;
608
623
Ok ( num_nodes)
609
624
}
610
625
@@ -641,7 +656,7 @@ mod dangerous {
641
656
642
657
let status = unsafe { f ( session_ptr, i, allocator_ptr, & mut name_bytes) } ;
643
658
status_to_result ( status) . map_err ( OrtError :: InputName ) ?;
644
- assert_ne ! ( name_bytes, std :: ptr :: null_mut ( ) ) ;
659
+ assert_not_null_pointer ( name_bytes, "InputName" ) ? ;
645
660
646
661
// FIXME: Is it safe to keep ownership of the memory?
647
662
let name = char_p_to_string ( name_bytes) ?;
@@ -692,23 +707,22 @@ mod dangerous {
692
707
693
708
let status = unsafe { f ( session_ptr, i, & mut typeinfo_ptr) } ;
694
709
status_to_result ( status) . map_err ( OrtError :: GetTypeInfo ) ?;
695
- assert_ne ! ( typeinfo_ptr, std :: ptr :: null_mut ( ) ) ;
710
+ assert_not_null_pointer ( typeinfo_ptr, "TypeInfo" ) ? ;
696
711
697
712
let mut tensor_info_ptr: * const sys:: OrtTensorTypeAndShapeInfo = std:: ptr:: null_mut ( ) ;
698
713
let status = unsafe {
699
714
g_ort ( ) . CastTypeInfoToTensorInfo . unwrap ( ) ( typeinfo_ptr, & mut tensor_info_ptr)
700
715
} ;
701
716
status_to_result ( status) . map_err ( OrtError :: CastTypeInfoToTensorInfo ) ?;
702
- assert_ne ! ( tensor_info_ptr, std :: ptr :: null_mut ( ) ) ;
717
+ assert_not_null_pointer ( tensor_info_ptr, "TensorInfo" ) ? ;
703
718
704
719
let mut type_sys = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ;
705
720
let status =
706
721
unsafe { g_ort ( ) . GetTensorElementType . unwrap ( ) ( tensor_info_ptr, & mut type_sys) } ;
707
722
status_to_result ( status) . map_err ( OrtError :: TensorElementType ) ?;
708
- assert_ne ! (
709
- type_sys,
710
- sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
711
- ) ;
723
+ ( type_sys != sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED )
724
+ . then ( || ( ) )
725
+ . ok_or ( OrtError :: UndefinedTensorElementType ) ?;
712
726
// This transmute should be safe since its value is read from GetTensorElementType which we must trust.
713
727
let io_type: TensorElementDataType = unsafe { std:: mem:: transmute ( type_sys) } ;
714
728
0 commit comments