Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataFusion expr conversion #349

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
711 changes: 506 additions & 205 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -56,6 +56,8 @@ cargo_metadata = "0.18.1"
criterion = { version = "0.5.1", features = ["html_reports"] }
croaring = "1.0.1"
csv = "1.3.0"
datafusion-common = "39.0.0"
datafusion-expr = "39.0.0"
derive_builder = "0.20.0"
divan = "0.1.14"
duckdb = { version = "0.10.1", features = ["bundled"] }
8 changes: 2 additions & 6 deletions vortex-array/benches/compare.rs
Original file line number Diff line number Diff line change
@@ -27,9 +27,7 @@ fn filter_bool_indices(c: &mut Criterion) {

group.bench_function("compare_bool", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
let indices = vortex::compute::compare::compare(&arr, &arr2, Operator::Gte).unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
@@ -53,9 +51,7 @@ fn filter_indices(c: &mut Criterion) {

group.bench_function("compare_int", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
let indices = vortex::compute::compare::compare(&arr, &arr2, Operator::Gte).unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
24 changes: 12 additions & 12 deletions vortex-array/src/array/bool/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -13,13 +13,13 @@ impl CompareFn for BoolArray {
let lhs = self.boolean_buffer();
let rhs = flattened.boolean_buffer();
let result_buf = match op {
Operator::EqualTo => lhs.bitxor(&rhs).not(),
Operator::NotEqualTo => lhs.bitxor(&rhs),
Operator::Eq => lhs.bitxor(&rhs).not(),
Operator::NotEq => lhs.bitxor(&rhs),

Operator::GreaterThan => lhs.bitand(&rhs.not()),
Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()),
Operator::LessThan => lhs.not().bitand(&rhs),
Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs),
Operator::Gt => lhs.bitand(&rhs.not()),
Operator::Gte => lhs.bitor(&rhs.not()),
Operator::Lt => lhs.not().bitand(&rhs),
Operator::Lte => lhs.not().bitor(&rhs),
};
Ok(BoolArray::from(
self.validity()
@@ -58,10 +58,10 @@ mod test {
)
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::Eq)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::NotEq)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

@@ -71,16 +71,16 @@ mod test {
)
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);
Ok(())
}
12 changes: 6 additions & 6 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -78,10 +78,10 @@ mod test {
])
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::Eq)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &arr, Operator::NotEq)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

@@ -101,16 +101,16 @@ mod test {
])
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
let matches = compare(&arr, &other, Operator::Lt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gte)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
let matches = compare(&other, &arr, Operator::Gt)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
Ok(())
}
4 changes: 4 additions & 0 deletions vortex-dtype/src/field_paths.rs
Original file line number Diff line number Diff line change
@@ -24,6 +24,10 @@ impl FieldPath {
Some(Self::builder().join_all(new_field_names).build())
}
}

/// Borrows the identifiers stored in this path's `field_names` as a slice.
pub fn parts(&self) -> &[FieldIdentifier] {
&self.field_names
}
}

#[derive(Clone, Debug, PartialEq)]
6 changes: 4 additions & 2 deletions vortex-expr/Cargo.toml
Original file line number Diff line number Diff line change
@@ -15,14 +15,16 @@ rust-version = { workspace = true }
workspace = true

[dependencies]
datafusion-common = { workspace = true, optional = true }
datafusion-expr = { workspace = true, optional = true }
vortex-dtype = { path = "../vortex-dtype" }
vortex-error = { path = "../vortex-error" }
vortex-scalar = { path = "../vortex-scalar" }
serde = { workspace = true, optional = true, features = ["derive"] }


[dev-dependencies]


[features]
default = []
datafusion = ["dep:datafusion-common", "dep:datafusion-expr", "vortex-scalar/datafusion"]
serde = ["dep:serde", "vortex-dtype/serde", "vortex-scalar/serde"]
63 changes: 63 additions & 0 deletions vortex-expr/src/datafusion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#![cfg(feature = "datafusion")]
use datafusion_common::Column;
use datafusion_expr::{BinaryExpr, Expr};
use vortex_dtype::field_paths::{FieldIdentifier, FieldPath};
use vortex_scalar::Scalar;

use crate::expressions::{Predicate, Value};
use crate::operators::Operator;

impl From<Predicate> for Expr {
fn from(value: Predicate) -> Self {
Expr::BinaryExpr(BinaryExpr::new(
Box::new(FieldPathWrapper(value.left).into()),
value.op.into(),
Box::new(value.right.into()),
))
}
}

/// Maps each Vortex comparison [`Operator`] onto the corresponding DataFusion operator.
impl From<Operator> for datafusion_expr::Operator {
    fn from(value: Operator) -> Self {
        // `Self` is `datafusion_expr::Operator` here; the match is exhaustive on purpose so
        // that adding a Vortex operator forces this mapping to be revisited.
        match value {
            Operator::Eq => Self::Eq,
            Operator::NotEq => Self::NotEq,
            Operator::Gt => Self::Gt,
            Operator::Gte => Self::GtEq,
            Operator::Lt => Self::Lt,
            Operator::Lte => Self::LtEq,
        }
    }
}

impl From<Value> for Expr {
fn from(value: Value) -> Self {
match value {
Value::Field(field_path) => FieldPathWrapper(field_path).into(),
Value::Literal(literal) => ScalarWrapper(literal).into(),
}
}
}

/// Local newtype around [`FieldPath`].
///
/// Both `FieldPath` and `Expr` are imported from other crates, so the conversion between
/// them is implemented on this wrapper rather than on `FieldPath` directly.
struct FieldPathWrapper(FieldPath);

impl From<FieldPathWrapper> for Expr {
    /// Renders the path as a column name and wraps it in [`Expr::Column`].
    ///
    /// Name parts are wrapped in double quotes, with any embedded `"` doubled (`""`), the
    /// SQL convention for a quote inside a quoted identifier. List-index parts are rendered
    /// as `[idx]`.
    fn from(value: FieldPathWrapper) -> Self {
        let mut field = String::new();
        for part in value.0.parts() {
            match part {
                FieldIdentifier::Name(identifier) => {
                    field.push('"');
                    // Double embedded quotes so the name cannot terminate the quoted
                    // identifier early.
                    field.push_str(&identifier.to_string().replace('"', "\"\""));
                    field.push('"');
                }
                FieldIdentifier::ListIndex(idx) => field.push_str(&format!("[{}]", idx)),
            }
        }

        // NOTE(review): parts are concatenated with no separator, so two consecutive names
        // `a`, `b` produce `"a""b"`, which a SQL identifier parser would presumably read as
        // the single identifier `a"b`. Confirm how multi-part paths are meant to be encoded.
        Expr::Column(Column::from(field))
    }
}

/// Local newtype around [`Scalar`] used to implement the conversion into a DataFusion
/// literal expression.
struct ScalarWrapper(Scalar);

impl From<ScalarWrapper> for Expr {
    /// Converts the wrapped scalar and wraps the result in [`Expr::Literal`].
    fn from(value: ScalarWrapper) -> Self {
        let ScalarWrapper(scalar) = value;
        Expr::Literal(scalar.into())
    }
}
12 changes: 6 additions & 6 deletions vortex-expr/src/display.rs
Original file line number Diff line number Diff line change
@@ -42,12 +42,12 @@ impl Display for Value {
impl Display for Operator {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let display = match &self {
Operator::EqualTo => "=",
Operator::NotEqualTo => "!=",
Operator::GreaterThan => ">",
Operator::GreaterThanOrEqualTo => ">=",
Operator::LessThan => "<",
Operator::LessThanOrEqualTo => "<=",
Operator::Eq => "=",
Operator::NotEq => "!=",
Operator::Gt => ">",
Operator::Gte => ">=",
Operator::Lt => "<",
Operator::Lte => "<=",
};
write!(f, "{display}")
}
14 changes: 7 additions & 7 deletions vortex-expr/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -50,47 +50,47 @@ impl Value {
pub fn eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::EqualTo,
op: Operator::Eq,
right: self,
}
}

pub fn not_eq(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::NotEqualTo.inverse(),
op: Operator::NotEq.inverse(),
right: self,
}
}

pub fn gt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::GreaterThan.inverse(),
op: Operator::Gt.inverse(),
right: self,
}
}

pub fn gte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::GreaterThanOrEqualTo.inverse(),
op: Operator::Gte.inverse(),
right: self,
}
}

pub fn lt(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::LessThan.inverse(),
op: Operator::Lt.inverse(),
right: self,
}
}

pub fn lte(self, field: impl Into<FieldPath>) -> Predicate {
Predicate {
left: field.into(),
op: Operator::LessThanOrEqualTo.inverse(),
op: Operator::Lte.inverse(),
right: self,
}
}
@@ -109,7 +109,7 @@ mod test {
let field = field("id");
let expr = Predicate {
left: field,
op: Operator::EqualTo,
op: Operator::Eq,
right: value,
};
assert_eq!(format!("{}", expr), "($id = 1)");
12 changes: 6 additions & 6 deletions vortex-expr/src/field_paths.rs
Original file line number Diff line number Diff line change
@@ -17,47 +17,47 @@ impl FieldPathOperations for FieldPath {
fn eq(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::EqualTo,
op: Operator::Eq,
right: other,
}
}

fn not_eq(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::NotEqualTo,
op: Operator::NotEq,
right: other,
}
}

fn gt(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThan,
op: Operator::Gt,
right: other,
}
}

fn gte(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::GreaterThanOrEqualTo,
op: Operator::Gte,
right: other,
}
}

fn lt(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::LessThan,
op: Operator::Lt,
right: other,
}
}

fn lte(self, other: Value) -> Predicate {
Predicate {
left: self,
op: Operator::LessThanOrEqualTo,
op: Operator::Lte,
right: other,
}
}
1 change: 1 addition & 0 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![feature(iter_intersperse)]
extern crate core;

mod datafusion;
mod display;
pub mod expressions;
pub mod field_paths;
48 changes: 24 additions & 24 deletions vortex-expr/src/operators.rs
Original file line number Diff line number Diff line change
@@ -8,25 +8,25 @@ use crate::expressions::Predicate;
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Operator {
// comparison
EqualTo,
NotEqualTo,
GreaterThan,
GreaterThanOrEqualTo,
LessThan,
LessThanOrEqualTo,
Eq,
NotEq,
Gt,
Gte,
Lt,
Lte,
}

impl ops::Not for Predicate {
type Output = Self;

fn not(self) -> Self::Output {
let inverse_op = match self.op {
Operator::EqualTo => Operator::NotEqualTo,
Operator::NotEqualTo => Operator::EqualTo,
Operator::GreaterThan => Operator::LessThanOrEqualTo,
Operator::GreaterThanOrEqualTo => Operator::LessThan,
Operator::LessThan => Operator::GreaterThanOrEqualTo,
Operator::LessThanOrEqualTo => Operator::GreaterThan,
Operator::Eq => Operator::NotEq,
Operator::NotEq => Operator::Eq,
Operator::Gt => Operator::Lte,
Operator::Gte => Operator::Lt,
Operator::Lt => Operator::Gte,
Operator::Lte => Operator::Gt,
};
Predicate {
left: self.left,
@@ -39,23 +39,23 @@ impl ops::Not for Predicate {
impl Operator {
pub fn inverse(self) -> Self {
match self {
Operator::EqualTo => Operator::NotEqualTo,
Operator::NotEqualTo => Operator::EqualTo,
Operator::GreaterThan => Operator::LessThanOrEqualTo,
Operator::GreaterThanOrEqualTo => Operator::LessThan,
Operator::LessThan => Operator::GreaterThanOrEqualTo,
Operator::LessThanOrEqualTo => Operator::GreaterThan,
Operator::Eq => Operator::NotEq,
Operator::NotEq => Operator::Eq,
Operator::Gt => Operator::Lte,
Operator::Gte => Operator::Lt,
Operator::Lt => Operator::Gte,
Operator::Lte => Operator::Gt,
}
}

pub fn to_predicate<T: NativePType>(&self) -> fn(&T, &T) -> bool {
match self {
Operator::EqualTo => PartialEq::eq,
Operator::NotEqualTo => PartialEq::ne,
Operator::GreaterThan => PartialOrd::gt,
Operator::GreaterThanOrEqualTo => PartialOrd::ge,
Operator::LessThan => PartialOrd::lt,
Operator::LessThanOrEqualTo => PartialOrd::le,
Operator::Eq => PartialEq::eq,
Operator::NotEq => PartialEq::ne,
Operator::Gt => PartialOrd::gt,
Operator::Gte => PartialOrd::ge,
Operator::Lt => PartialOrd::lt,
Operator::Lte => PartialOrd::le,
}
}
}
2 changes: 2 additions & 0 deletions vortex-scalar/Cargo.toml
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@ edition = { workspace = true }
rust-version = { workspace = true }

[dependencies]
datafusion-common = { workspace = true, optional = true }
flatbuffers = { workspace = true, optional = true }
flexbuffers = { workspace = true, optional = true }
itertools = { workspace = true }
@@ -34,6 +35,7 @@ workspace = true
[features]
# Uncomment for improved IntelliJ support
# default = ["flatbuffers", "proto", "serde"]
datafusion = ["dep:datafusion-common"]
flatbuffers = [
"dep:flatbuffers",
"dep:flexbuffers",
68 changes: 68 additions & 0 deletions vortex-scalar/src/datafusion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#![cfg(feature = "datafusion")]
use datafusion_common::ScalarValue;
use vortex_dtype::{DType, PType};

use crate::{PValue, Scalar};

impl From<Scalar> for ScalarValue {
fn from(value: Scalar) -> Self {
match value.dtype {
DType::Null => ScalarValue::Null,
DType::Bool(_) => ScalarValue::Boolean(value.value.as_bool().expect("should be bool")),
DType::Primitive(ptype, _) => {
let pvalue = value.value.as_pvalue().expect("should be pvalue");
match pvalue {
None => match ptype {
PType::U8 => ScalarValue::UInt8(None),
PType::U16 => ScalarValue::UInt16(None),
PType::U32 => ScalarValue::UInt32(None),
PType::U64 => ScalarValue::UInt64(None),
PType::I8 => ScalarValue::Int8(None),
PType::I16 => ScalarValue::Int16(None),
PType::I32 => ScalarValue::Int32(None),
PType::I64 => ScalarValue::Int64(None),
PType::F16 => ScalarValue::Float16(None),
PType::F32 => ScalarValue::Float32(None),
PType::F64 => ScalarValue::Float64(None),
},
Some(pvalue) => match pvalue {
PValue::U8(v) => ScalarValue::UInt8(Some(v)),
PValue::U16(v) => ScalarValue::UInt16(Some(v)),
PValue::U32(v) => ScalarValue::UInt32(Some(v)),
PValue::U64(v) => ScalarValue::UInt64(Some(v)),
PValue::I8(v) => ScalarValue::Int8(Some(v)),
PValue::I16(v) => ScalarValue::Int16(Some(v)),
PValue::I32(v) => ScalarValue::Int32(Some(v)),
PValue::I64(v) => ScalarValue::Int64(Some(v)),
PValue::F16(v) => ScalarValue::Float16(Some(v)),
PValue::F32(v) => ScalarValue::Float32(Some(v)),
PValue::F64(v) => ScalarValue::Float64(Some(v)),
},
}
}
DType::Utf8(_) => ScalarValue::Utf8(
value
.value
.as_buffer_string()
.expect("should be buffer string")
.map(|b| b.as_str().to_string()),
),
DType::Binary(_) => ScalarValue::Binary(
value
.value
.as_buffer()
.expect("should be buffer")
.map(|b| b.as_slice().to_vec()),
),
DType::Struct(..) => {
todo!("struct scalar conversion")
}
DType::List(..) => {
todo!("list scalar conversion")
}
DType::Extension(..) => {
todo!("extension scalar conversion")
}
}
}
}
1 change: 1 addition & 0 deletions vortex-scalar/src/lib.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ use vortex_dtype::DType;

mod binary;
mod bool;
mod datafusion;
mod display;
mod extension;
mod list;