@@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
322
322
}
323
323
}
324
324
325
- // FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
326
- // FIXME: Add tests to cover the commented out types
327
- /// Enum mapping ONNX Runtime's supported tensor types
328
- #[ derive( Debug ) ]
329
- #[ cfg_attr( not( windows) , repr( u32 ) ) ]
330
- #[ cfg_attr( windows, repr( i32 ) ) ]
331
- pub enum TensorElementDataType {
332
- /// 32-bit floating point, equivalent to Rust's `f32`
333
- Float = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt ,
334
- /// Unsigned 8-bit int, equivalent to Rust's `u8`
335
- Uint8 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt ,
336
- /// Signed 8-bit int, equivalent to Rust's `i8`
337
- Int8 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt ,
338
- /// Unsigned 16-bit int, equivalent to Rust's `u16`
339
- Uint16 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt ,
340
- /// Signed 16-bit int, equivalent to Rust's `i16`
341
- Int16 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt ,
342
- /// Signed 32-bit int, equivalent to Rust's `i32`
343
- Int32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt ,
344
- /// Signed 64-bit int, equivalent to Rust's `i64`
345
- Int64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt ,
346
- /// String, equivalent to Rust's `String`
347
- String = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt ,
348
- // /// Boolean, equivalent to Rust's `bool`
349
- // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
350
- // /// 16-bit floating point, equivalent to Rust's `f16`
351
- // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
352
- /// 64-bit floating point, equivalent to Rust's `f64`
353
- Double = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt ,
354
- /// Unsigned 32-bit int, equivalent to Rust's `u32`
355
- Uint32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt ,
356
- /// Unsigned 64-bit int, equivalent to Rust's `u64`
357
- Uint64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt ,
358
- // /// Complex 64-bit floating point, equivalent to Rust's `???`
359
- // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
360
- // /// Complex 128-bit floating point, equivalent to Rust's `???`
361
- // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
362
- // /// Brain 16-bit floating point
363
- // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
364
- }
365
-
366
- impl Into < sys:: ONNXTensorElementDataType > for TensorElementDataType {
367
- fn into ( self ) -> sys:: ONNXTensorElementDataType {
368
- use TensorElementDataType :: * ;
369
- match self {
370
- Float => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT ,
371
- Uint8 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 ,
372
- Int8 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 ,
373
- Uint16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 ,
374
- Int16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 ,
375
- Int32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ,
376
- Int64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 ,
377
- String => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING ,
378
- // Bool => {
379
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
380
- // }
381
- // Float16 => {
382
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
383
- // }
384
- Double => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE ,
385
- Uint32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 ,
386
- Uint64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 ,
387
- // Complex64 => {
388
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
389
- // }
390
- // Complex128 => {
391
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
392
- // }
393
- // Bfloat16 => {
394
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
395
- // }
396
- }
397
- }
398
- }
399
-
400
- /// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
401
- pub trait TypeToTensorElementDataType {
402
- /// Return the ONNX type for a Rust type
403
- fn tensor_element_data_type ( ) -> TensorElementDataType ;
404
-
405
- /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
406
- fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > ;
407
- }
408
-
409
- macro_rules! impl_type_trait {
410
- ( $type_: ty, $variant: ident) => {
411
- impl TypeToTensorElementDataType for $type_ {
412
- fn tensor_element_data_type( ) -> TensorElementDataType {
413
- // unsafe { std::mem::transmute(TensorElementDataType::$variant) }
414
- TensorElementDataType :: $variant
415
- }
416
-
417
- fn try_utf8_bytes( & self ) -> Option <& [ u8 ] > {
418
- None
419
- }
420
- }
421
- } ;
422
- }
423
-
424
- impl_type_trait ! ( f32 , Float ) ;
425
- impl_type_trait ! ( u8 , Uint8 ) ;
426
- impl_type_trait ! ( i8 , Int8 ) ;
427
- impl_type_trait ! ( u16 , Uint16 ) ;
428
- impl_type_trait ! ( i16 , Int16 ) ;
429
- impl_type_trait ! ( i32 , Int32 ) ;
430
- impl_type_trait ! ( i64 , Int64 ) ;
431
- // impl_type_trait!(bool, Bool);
432
- // impl_type_trait!(f16, Float16);
433
- impl_type_trait ! ( f64 , Double ) ;
434
- impl_type_trait ! ( u32 , Uint32 ) ;
435
- impl_type_trait ! ( u64 , Uint64 ) ;
436
- // impl_type_trait!(, Complex64);
437
- // impl_type_trait!(, Complex128);
438
- // impl_type_trait!(, Bfloat16);
439
-
440
- /// Adapter for common Rust string types to Onnx strings.
441
- ///
442
- /// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
443
- /// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
444
- /// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
445
- /// types (which might implement `AsRef<str>` at some point in the future).
446
- pub trait Utf8Data {
447
- /// Returns the utf8 contents.
448
- fn utf8_bytes ( & self ) -> & [ u8 ] ;
449
- }
450
-
451
- impl Utf8Data for String {
452
- fn utf8_bytes ( & self ) -> & [ u8 ] {
453
- self . as_bytes ( )
454
- }
455
- }
456
-
457
- impl < ' a > Utf8Data for & ' a str {
458
- fn utf8_bytes ( & self ) -> & [ u8 ] {
459
- self . as_bytes ( )
460
- }
461
- }
462
-
463
- impl < T : Utf8Data > TypeToTensorElementDataType for T {
464
- fn tensor_element_data_type ( ) -> TensorElementDataType {
465
- TensorElementDataType :: String
466
- }
467
-
468
- fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > {
469
- Some ( self . utf8_bytes ( ) )
470
- }
471
- }
472
-
473
325
/// Allocator type
474
326
#[ derive( Debug , Clone ) ]
475
327
#[ repr( i32 ) ]
0 commit comments