@@ -343,8 +343,8 @@ pub enum TensorElementDataType {
343
343
Int32 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt ,
344
344
/// Signed 64-bit int, equivalent to Rust's `i64`
345
345
Int64 = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt ,
346
- // // / String, equivalent to Rust's `String`
347
- // String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
346
+ /// String, equivalent to Rust's `String`
347
+ String = sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt ,
348
348
// /// Boolean, equivalent to Rust's `bool`
349
349
// Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
350
350
// /// 16-bit floating point, equivalent to Rust's `f16`
@@ -374,9 +374,7 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
374
374
Int16 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 ,
375
375
Int32 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 ,
376
376
Int64 => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 ,
377
- // String => {
378
- // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
379
- // }
377
+ String => sys:: ONNXTensorElementDataType :: ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING ,
380
378
// Bool => {
381
379
// sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
382
380
// }
@@ -402,15 +400,22 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
402
400
/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
403
401
pub trait TypeToTensorElementDataType {
404
402
/// Return the ONNX type for a Rust type
405
- fn tensor_element_data_type ( ) -> sys:: ONNXTensorElementDataType ;
403
+ fn tensor_element_data_type ( ) -> TensorElementDataType ;
404
+
405
+ /// If the type is `String`, returns `Some` with utf8 contents, else `None`.
406
+ fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > ;
406
407
}
407
408
408
409
macro_rules! impl_type_trait {
409
410
( $type_: ty, $variant: ident) => {
410
411
impl TypeToTensorElementDataType for $type_ {
411
- fn tensor_element_data_type( ) -> sys :: ONNXTensorElementDataType {
412
+ fn tensor_element_data_type( ) -> TensorElementDataType {
412
413
// unsafe { std::mem::transmute(TensorElementDataType::$variant) }
413
- TensorElementDataType :: $variant. into( )
414
+ TensorElementDataType :: $variant
415
+ }
416
+
417
+ fn try_utf8_bytes( & self ) -> Option <& [ u8 ] > {
418
+ None
414
419
}
415
420
}
416
421
} ;
@@ -423,7 +428,6 @@ impl_type_trait!(u16, Uint16);
423
428
impl_type_trait ! ( i16 , Int16 ) ;
424
429
impl_type_trait ! ( i32 , Int32 ) ;
425
430
impl_type_trait ! ( i64 , Int64 ) ;
426
- // impl_type_trait!(String, String);
427
431
// impl_type_trait!(bool, Bool);
428
432
// impl_type_trait!(f16, Float16);
429
433
impl_type_trait ! ( f64 , Double ) ;
@@ -433,6 +437,39 @@ impl_type_trait!(u64, Uint64);
433
437
// impl_type_trait!(, Complex128);
434
438
// impl_type_trait!(, Bfloat16);
435
439
440
/// Bridges common Rust string types to ONNX string tensor elements.
///
/// Both `String` and `&str` should be usable as [TensorElementDataType::String] data. A blanket
/// implementation over everything that is `AsRef<str>` is not possible, because it would conflict
/// with the [TypeToTensorElementDataType] implementations for the primitive numeric types (which
/// might implement `AsRef<str>` at some point in the future). This dedicated trait sidesteps
/// that conflict.
pub trait Utf8Data {
    /// Returns the UTF-8 contents as a byte slice.
    fn utf8_bytes(&self) -> &[u8];
}
450
+
451
+ impl Utf8Data for String {
452
+ fn utf8_bytes ( & self ) -> & [ u8 ] {
453
+ self . as_bytes ( )
454
+ }
455
+ }
456
+
457
+ impl < ' a > Utf8Data for & ' a str {
458
+ fn utf8_bytes ( & self ) -> & [ u8 ] {
459
+ self . as_bytes ( )
460
+ }
461
+ }
462
+
463
+ impl < T : Utf8Data > TypeToTensorElementDataType for T {
464
+ fn tensor_element_data_type ( ) -> TensorElementDataType {
465
+ TensorElementDataType :: String
466
+ }
467
+
468
+ fn try_utf8_bytes ( & self ) -> Option < & [ u8 ] > {
469
+ Some ( self . utf8_bytes ( ) )
470
+ }
471
+ }
472
+
436
473
/// Allocator type
437
474
#[ derive( Debug , Clone ) ]
438
475
#[ repr( i32 ) ]
0 commit comments