Skip to content

Commit

Permalink
Extend coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
aandres committed Dec 14, 2023
1 parent fa774a4 commit 0b07bc5
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 72 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ repos:
rev: 1.7.5
hooks:
- id: bandit
files: "^ptars/.*.py$"
additional_dependencies:
- ".[toml]"
args: ["--config=pyproject.toml"]
124 changes: 62 additions & 62 deletions protos/bench.proto
Original file line number Diff line number Diff line change
Expand Up @@ -82,68 +82,68 @@ message ExampleMessage {
repeated google.type.TimeOfDay time_of_day_values = 57;
repeated google.protobuf.Empty empty_values = 58;


// Map with int32 keys
map<int32, double> double_int32_map = 88;
map<int32, float> float_int32_map = 89;
map<int32, int32> int32_int32_map = 90;
map<int32, int64> int64_int32_map = 91;
map<int32, uint32> uint32_int32_map = 92;
map<int32, uint64> uint64_int32_map = 93;
map<int32, sint32> sint32_int32_map = 94;
map<int32, sint64> sint64_int32_map = 95;
map<int32, fixed32> fixed32_int32_map = 96;
map<int32, fixed64> fixed64_int32_map = 97;
map<int32, sfixed32> sfixed32_int32_map = 98;
map<int32, sfixed64> sfixed64_int32_map = 99;
map<int32, bool> bool_int32_map = 100;
map<int32, string> string_int32_map = 101;
map<int32, bytes> bytes_int32_map = 102;
map<int32, google.protobuf.DoubleValue> wrapped_double_int32_map = 103;
map<int32, google.protobuf.FloatValue> wrapped_float_int32_map = 104;
map<int32, google.protobuf.Int32Value> wrapped_int32_int32_map = 105;
map<int32, google.protobuf.Int64Value> wrapped_int64_int32_map = 106;
map<int32, google.protobuf.UInt32Value> wrapped_uint32_int32_map = 107;
map<int32, google.protobuf.UInt64Value> wrapped_uint64_int32_map = 108;
map<int32, google.protobuf.BoolValue> wrapped_bool_int32_map = 109;
map<int32, google.protobuf.StringValue> wrapped_string_int32_map = 110;
map<int32, google.protobuf.BytesValue> wrapped_bytes_int32_map = 111;
map<int32, ExampleEnum> example_enum_int32_map = 112;
map<int32, google.protobuf.Timestamp> timestamp_int32_map = 113;
map<int32, google.type.Date> date_int32_map = 114;
map<int32, google.type.TimeOfDay> time_of_day_int32_map = 115;
map<int32, google.protobuf.Empty> empty_int32_map = 116;

// Map with string keys
map<string, double> double_string_map = 117;
map<string, float> float_string_map = 118;
map<string, int32> int32_string_map = 119;
map<string, int64> int64_string_map = 120;
map<string, uint32> uint32_string_map = 121;
map<string, uint64> uint64_string_map = 122;
map<string, sint32> sint32_string_map = 123;
map<string, sint64> sint64_string_map = 124;
map<string, fixed32> fixed32_string_map = 125;
map<string, fixed64> fixed64_string_map = 126;
map<string, sfixed32> sfixed32_string_map = 127;
map<string, sfixed64> sfixed64_string_map = 128;
map<string, bool> bool_string_map = 129;
map<string, string> string_string_map = 130;
map<string, bytes> bytes_string_map = 131;
map<string, google.protobuf.DoubleValue> wrapped_double_string_map = 132;
map<string, google.protobuf.FloatValue> wrapped_float_string_map = 133;
map<string, google.protobuf.Int32Value> wrapped_int32_string_map = 134;
map<string, google.protobuf.Int64Value> wrapped_int64_string_map = 135;
map<string, google.protobuf.UInt32Value> wrapped_uint32_string_map = 136;
map<string, google.protobuf.UInt64Value> wrapped_uint64_string_map = 137;
map<string, google.protobuf.BoolValue> wrapped_bool_string_map = 138;
map<string, google.protobuf.StringValue> wrapped_string_string_map = 139;
map<string, google.protobuf.BytesValue> wrapped_bytes_string_map = 140;
map<string, ExampleEnum> example_enum_string_map = 141;
map<string, google.protobuf.Timestamp> timestamp_string_map = 142;
map<string, google.type.Date> date_string_map = 143;
map<string, google.type.TimeOfDay> time_of_day_string_map = 144;
map<string, google.protobuf.Empty> empty_string_map = 145;
//
// // Map with int32 keys
// map<int32, double> double_int32_map = 88;
// map<int32, float> float_int32_map = 89;
// map<int32, int32> int32_int32_map = 90;
// map<int32, int64> int64_int32_map = 91;
// map<int32, uint32> uint32_int32_map = 92;
// map<int32, uint64> uint64_int32_map = 93;
// map<int32, sint32> sint32_int32_map = 94;
// map<int32, sint64> sint64_int32_map = 95;
// map<int32, fixed32> fixed32_int32_map = 96;
// map<int32, fixed64> fixed64_int32_map = 97;
// map<int32, sfixed32> sfixed32_int32_map = 98;
// map<int32, sfixed64> sfixed64_int32_map = 99;
// map<int32, bool> bool_int32_map = 100;
// map<int32, string> string_int32_map = 101;
// map<int32, bytes> bytes_int32_map = 102;
// map<int32, google.protobuf.DoubleValue> wrapped_double_int32_map = 103;
// map<int32, google.protobuf.FloatValue> wrapped_float_int32_map = 104;
// map<int32, google.protobuf.Int32Value> wrapped_int32_int32_map = 105;
// map<int32, google.protobuf.Int64Value> wrapped_int64_int32_map = 106;
// map<int32, google.protobuf.UInt32Value> wrapped_uint32_int32_map = 107;
// map<int32, google.protobuf.UInt64Value> wrapped_uint64_int32_map = 108;
// map<int32, google.protobuf.BoolValue> wrapped_bool_int32_map = 109;
// map<int32, google.protobuf.StringValue> wrapped_string_int32_map = 110;
// map<int32, google.protobuf.BytesValue> wrapped_bytes_int32_map = 111;
// map<int32, ExampleEnum> example_enum_int32_map = 112;
// map<int32, google.protobuf.Timestamp> timestamp_int32_map = 113;
// map<int32, google.type.Date> date_int32_map = 114;
// map<int32, google.type.TimeOfDay> time_of_day_int32_map = 115;
// map<int32, google.protobuf.Empty> empty_int32_map = 116;
//
// // Map with string keys
// map<string, double> double_string_map = 117;
// map<string, float> float_string_map = 118;
// map<string, int32> int32_string_map = 119;
// map<string, int64> int64_string_map = 120;
// map<string, uint32> uint32_string_map = 121;
// map<string, uint64> uint64_string_map = 122;
// map<string, sint32> sint32_string_map = 123;
// map<string, sint64> sint64_string_map = 124;
// map<string, fixed32> fixed32_string_map = 125;
// map<string, fixed64> fixed64_string_map = 126;
// map<string, sfixed32> sfixed32_string_map = 127;
// map<string, sfixed64> sfixed64_string_map = 128;
// map<string, bool> bool_string_map = 129;
// map<string, string> string_string_map = 130;
// map<string, bytes> bytes_string_map = 131;
// map<string, google.protobuf.DoubleValue> wrapped_double_string_map = 132;
// map<string, google.protobuf.FloatValue> wrapped_float_string_map = 133;
// map<string, google.protobuf.Int32Value> wrapped_int32_string_map = 134;
// map<string, google.protobuf.Int64Value> wrapped_int64_string_map = 135;
// map<string, google.protobuf.UInt32Value> wrapped_uint32_string_map = 136;
// map<string, google.protobuf.UInt64Value> wrapped_uint64_string_map = 137;
// map<string, google.protobuf.BoolValue> wrapped_bool_string_map = 138;
// map<string, google.protobuf.StringValue> wrapped_string_string_map = 139;
// map<string, google.protobuf.BytesValue> wrapped_bytes_string_map = 140;
// map<string, ExampleEnum> example_enum_string_map = 141;
// map<string, google.protobuf.Timestamp> timestamp_string_map = 142;
// map<string, google.type.Date> date_string_map = 143;
// map<string, google.type.TimeOfDay> time_of_day_string_map = 144;
// map<string, google.protobuf.Empty> empty_string_map = 145;
}

message NestedExampleMessage {
Expand Down
33 changes: 24 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::ArrayData;
use arrow::buffer::Buffer;
use arrow::array::{ArrayData};
use arrow::buffer::{Buffer, NullBuffer};
use arrow::datatypes::{ArrowNativeType, ToByteSlice};
use arrow::pyarrow::ToPyArrow;
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -221,10 +221,13 @@ fn nested_messages_to_array(
);
is_valid.push(field.has_field(message.as_ref()));
}
println!("!!! {}", message_descriptor.full_name().to_string());

let arrays = fields_to_arrays(&nested_messages, message_descriptor);
return Arc::new(StructArray::from((arrays, Buffer::from_iter(is_valid))));
return if arrays.is_empty() {
Arc::new(StructArray::new_empty_fields(nested_messages.len(), Some(NullBuffer::from_iter(is_valid))))
}
else {
Arc::new(StructArray::from((arrays, Buffer::from_iter(is_valid))))
}
}

fn read_primitive<'b, T: Clone, A: From<Vec<T>> + Array + 'static>(
Expand Down Expand Up @@ -376,6 +379,7 @@ fn repeated_field_to_array(
DataType::Int32,
&ReflectValueRef::to_enum_value,
),

RuntimeType::Message(message_descriptor) => {
let mut repeated_messages: Vec<Box<dyn MessageDyn>> = Vec::new();
let mut offsets: Vec<i32> = Vec::new();
Expand All @@ -390,7 +394,13 @@ fn repeated_field_to_array(
}
offsets.push(i32::from_usize(repeated_messages.len()).unwrap());
}
let struct_array = Arc::new(StructArray::from(fields_to_arrays(&repeated_messages, message_descriptor)));

let arrays = fields_to_arrays(&repeated_messages, message_descriptor);
let struct_array: Arc<StructArray> = if arrays.is_empty() {
Arc::new(StructArray::new_empty_fields(repeated_messages.len(), None))
} else {
Arc::new(StructArray::from(arrays))
};
let list_data_type =
DataType::List(Arc::new(Field::new("item", struct_array.data_type().clone(), false)));
let list_data: ArrayData = ArrayData::builder(list_data_type)
Expand All @@ -411,7 +421,7 @@ fn field_to_array(
return match field.runtime_field_type() {
RuntimeFieldType::Singular(x) => singular_field_to_array(field, &x, messages),
RuntimeFieldType::Repeated(x) => repeated_field_to_array(field, &x, messages),
RuntimeFieldType::Map(_, _) => Err("repeated not supported"),
RuntimeFieldType::Map(_, _) => Err("map not supported"),
};
}

Expand Down Expand Up @@ -459,7 +469,7 @@ fn fields_to_arrays(

#[pymethods]
impl MessageHandler {
fn list_to_table(&self, values: Vec<Vec<u8>>, py: Python<'_>) -> PyResult<PyObject> {
fn list_to_record_batch(&self, values: Vec<Vec<u8>>, py: Python<'_>) -> PyResult<PyObject> {
let messages: Vec<Box<dyn MessageDyn>> = values
.iter()
.map(|x| {
Expand All @@ -470,7 +480,12 @@ impl MessageHandler {
.collect();
let arrays: Vec<(Arc<Field>, Arc<dyn Array>)> =
fields_to_arrays(&messages, &self.message_descriptor);
let batch = RecordBatch::from(StructArray::from(arrays));
let struct_array = if arrays.is_empty() {
StructArray::new_empty_fields(messages.len(), None)
} else {
StructArray::from(arrays)
};
let batch = RecordBatch::from(StructArray::from(struct_array));
return batch.to_pyarrow(py);
}
}
Expand Down
Loading

0 comments on commit 0b07bc5

Please sign in to comment.