Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node trait as reusable visitor #1918

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vortex-expr/src/forms/nnf.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use vortex_error::VortexResult;

use crate::traversal::{FoldChildren, FoldDown, FoldUp, FolderMut, Node as _};
use crate::traversal::{FoldChildren, FoldDown, FoldUp, FolderMut, NodeMut};
use crate::{not, BinaryExpr, ExprRef, Not, Operator};

/// Return an equivalent expression in Negative Normal Form (NNF).
Expand Down
13 changes: 9 additions & 4 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ pub mod pruning;
mod row_filter;
mod select;
pub mod transform;
#[allow(dead_code)]
mod traversal;
pub mod traversal;

pub use binary::*;
pub use column::*;
Expand All @@ -43,7 +42,7 @@ use vortex_array::{ArrayDType as _, ArrayData, Canonical, IntoArrayData as _};
use vortex_dtype::{DType, FieldName};
use vortex_error::{VortexResult, VortexUnwrap};

use crate::traversal::{Node, ReferenceCollector};
use crate::traversal::{DynNode, Node, ReferenceCollector};

pub type ExprRef = Arc<dyn VortexExpr>;

Expand Down Expand Up @@ -86,13 +85,19 @@ pub trait VortexExprExt {

impl VortexExprExt for ExprRef {
fn references(&self) -> HashSet<FieldName> {
let mut collector = ReferenceCollector::new();
let mut collector = ReferenceCollector::default();
// The collector is infallible, so we can unwrap the result
self.accept(&mut collector).vortex_unwrap();
collector.into_fields()
}
}

impl DynNode for dyn VortexExpr {
fn arc_children(&self) -> VortexResult<Vec<&Arc<Self>>> {
Ok(self.children())
}
}

/// Splits top level and operations into separate expressions
pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
let mut conjunctions = vec![];
Expand Down
2 changes: 1 addition & 1 deletion vortex-expr/src/transform/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use vortex_array::aliases::hash_set::HashSet;
use vortex_dtype::{FieldName, StructDType};
use vortex_error::{VortexExpect, VortexResult};

use crate::traversal::{FoldChildren, FoldDown, FoldUp, Folder, FolderMut, Node};
use crate::traversal::{FoldChildren, FoldDown, FoldUp, Folder, FolderMut, Node, NodeMut};
use crate::{get_item, ident, pack, ExprRef, GetItem, Identity, Select, SelectField};

/// Partition an expression over the fields of the scope.
Expand Down
2 changes: 1 addition & 1 deletion vortex-expr/src/transform/remove_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use itertools::Itertools;
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};

use crate::traversal::{MutNodeVisitor, Node, TransformResult};
use crate::traversal::{MutNodeVisitor, NodeMut, TransformResult};
use crate::{get_item, pack, ExprRef, Select};

/// Select is a useful expression, however it can be defined in terms of get_item & pack,
Expand Down
2 changes: 1 addition & 1 deletion vortex-expr/src/transform/simplify.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use vortex_error::VortexResult;

use crate::traversal::{FoldChildren, FoldUp, FolderMut, Node};
use crate::traversal::{FoldChildren, FoldUp, FolderMut, NodeMut};
use crate::{get_item, ident, Column, ExprRef, GetItem, Pack};

pub fn simplify(e: ExprRef) -> VortexResult<ExprRef> {
Expand Down
28 changes: 20 additions & 8 deletions vortex-expr/src/traversal/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod references;
mod visitor;

use std::sync::Arc;

use itertools::Itertools;
pub use references::ReferenceCollector;
use vortex_error::VortexResult;
Expand Down Expand Up @@ -185,10 +187,16 @@ pub trait Node: Sized {

fn accept_with_context<'a, V: Folder<'a, NodeTy = Self>>(
&'a self,
visitor: &mut V,
context: V::Context,
_visitor: &mut V,
_context: V::Context,
) -> VortexResult<FoldUp<V::Out>>;
}

pub trait DynNode {
fn arc_children(&self) -> VortexResult<Vec<&Arc<Self>>>;
}

pub trait NodeMut: Sized {
fn transform<V: MutNodeVisitor<NodeTy = Self>>(
self,
_visitor: &mut V,
Expand All @@ -201,9 +209,9 @@ pub trait Node: Sized {
) -> VortexResult<FoldUp<V::Out>>;
}

impl Node for ExprRef {
impl<T: DynNode + ?Sized> Node for Arc<T> {
// A pre-order traversal.
fn accept<'a, V: NodeVisitor<'a, NodeTy = ExprRef>>(
fn accept<'a, V: NodeVisitor<'a, NodeTy = Arc<T>>>(
&'a self,
visitor: &mut V,
) -> VortexResult<TraversalOrder> {
Expand All @@ -214,7 +222,7 @@ impl Node for ExprRef {
if ord == TraversalOrder::Skip {
return Ok(TraversalOrder::Continue);
}
for child in self.children() {
for child in self.arc_children()? {
if ord != TraversalOrder::Continue {
return Ok(ord);
}
Expand All @@ -235,8 +243,8 @@ impl Node for ExprRef {
FoldDown::Stop(out) => return Ok(FoldUp::Stop(out)),
FoldDown::SkipChildren => FoldChildren::Skipped,
FoldDown::Continue(child_context) => {
let mut new_children = Vec::with_capacity(self.children().len());
for child in self.children() {
let mut new_children = Vec::with_capacity(self.arc_children()?.len());
for child in self.arc_children()? {
match child.accept_with_context(visitor, child_context.clone())? {
FoldUp::Stop(out) => return Ok(FoldUp::Stop(out)),
FoldUp::Continue(out) => new_children.push(out),
Expand All @@ -248,7 +256,9 @@ impl Node for ExprRef {

visitor.visit_up(self, context, children)
}
}

impl NodeMut for ExprRef {
// A pre-order transform, with an option to ignore sub-tress (using visit_down).
fn transform<V: MutNodeVisitor<NodeTy = Self>>(
self,
Expand Down Expand Up @@ -336,7 +346,9 @@ mod tests {
use vortex_error::VortexResult;

use crate::traversal::visitor::pre_order_visit_down;
use crate::traversal::{MutNodeVisitor, Node, NodeVisitor, TransformResult, TraversalOrder};
use crate::traversal::{
MutNodeVisitor, Node, NodeMut, NodeVisitor, TransformResult, TraversalOrder,
};
use crate::{
BinaryExpr, Column, ExprRef, FieldName, Literal, Operator, VortexExpr, VortexExprExt,
};
Expand Down
7 changes: 1 addition & 6 deletions vortex-expr/src/traversal/references.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ use vortex_error::VortexResult;
use crate::traversal::{NodeVisitor, TraversalOrder};
use crate::{Column, ExprRef, GetItem, Select};

#[derive(Default)]
pub struct ReferenceCollector {
fields: HashSet<FieldName>,
}

impl ReferenceCollector {
pub fn new() -> Self {
Self {
fields: HashSet::new(),
}
}

pub fn with_set(set: HashSet<FieldName>) -> Self {
Self { fields: set }
}
Expand Down
2 changes: 2 additions & 0 deletions vortex-expr/src/traversal/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ where
}
}

#[allow(dead_code)]
pub fn pre_order_visit_up<'a, T: 'a + Node>(
f: impl FnMut(&'a T) -> VortexResult<TraversalOrder>,
) -> impl NodeVisitor<'a, NodeTy = T> {
Expand All @@ -47,6 +48,7 @@ pub fn pre_order_visit_up<'a, T: 'a + Node>(
}
}

#[allow(dead_code)]
pub fn pre_order_visit_down<'a, T: 'a + Node>(
f: impl FnMut(&'a T) -> VortexResult<TraversalOrder>,
) -> impl NodeVisitor<'a, NodeTy = T> {
Expand Down
4 changes: 4 additions & 0 deletions vortex-layout/src/layouts/chunked/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,8 @@ impl LayoutReader for ChunkedReader {
fn layout(&self) -> &LayoutData {
&self.layout
}

fn children(&self) -> VortexResult<Vec<&Arc<dyn LayoutReader>>> {
(0..self.nchunks()).map(|i| self.child(i)).collect()
}
}
4 changes: 4 additions & 0 deletions vortex-layout/src/layouts/flat/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,8 @@ impl LayoutReader for FlatReader {
fn layout(&self) -> &LayoutData {
&self.layout
}

fn children(&self) -> VortexResult<Vec<&Arc<dyn LayoutReader>>> {
Ok(vec![])
}
}
17 changes: 15 additions & 2 deletions vortex-layout/src/layouts/struct_/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,23 @@ impl StructReader {
.vortex_expect("Struct layout must have a struct DType, verified at construction")
}

pub(crate) fn nchildren(&self) -> usize {
self.field_readers.len()
}

/// Return the child reader for the chunk.
pub(crate) fn child(&self, name: &FieldName) -> VortexResult<&Arc<dyn LayoutReader>> {
let idx = self
.field_lookup
.as_ref()
.and_then(|lookup| lookup.get(name).copied())
.or_else(|| self.struct_dtype().find_name(name))
.ok_or_else(|| vortex_err!("Field {} not found in struct layout", name))?;
.ok_or_else(|| vortex_err!("StructReader::Field {} not found", name))?;
self.child_idx(idx)
}

// TODO: think about a Hashmap<FieldName, OnceLock<Arc<dyn LayoutReader>>> for large |fields|.
// TODO: think about a Hashmap<FieldName, OnceLock<Arc<dyn LayoutReader>>> for large |fields|.
fn child_idx(&self, idx: usize) -> VortexResult<&Arc<dyn LayoutReader>> {
self.field_readers[idx].get_or_try_init(|| {
let child_layout = self
.layout
Expand Down Expand Up @@ -110,6 +117,12 @@ impl LayoutReader for StructReader {
fn layout(&self) -> &LayoutData {
&self.layout
}

fn children(&self) -> VortexResult<Vec<&Arc<dyn LayoutReader>>> {
(0..self.nchildren())
.map(|idx| self.child_idx(idx))
.collect()
}
}

/// An expression wrapper that performs pointer equality.
Expand Down
1 change: 1 addition & 0 deletions vortex-layout/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod reader;
pub use reader::*;
pub mod segments;
pub mod strategies;
mod visitor;

/// The layout ID for a flat layout
pub(crate) const FLAT_LAYOUT_ID: LayoutId = LayoutId(1);
Expand Down
7 changes: 7 additions & 0 deletions vortex-layout/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use vortex_array::stats::{Stat, StatsSet};
use vortex_array::ArrayData;
use vortex_dtype::{DType, FieldPath};
use vortex_error::VortexResult;
use vortex_expr::traversal::DynNode;
use vortex_expr::ExprRef;
use vortex_scan::RowMask;

Expand All @@ -18,12 +19,18 @@ use crate::LayoutData;
pub trait LayoutReader: Send + Sync + ExprEvaluator + StatsEvaluator {
/// Returns the [`LayoutData`] of this reader.
fn layout(&self) -> &LayoutData;

fn children(&self) -> VortexResult<Vec<&Arc<dyn LayoutReader>>>;
}

impl LayoutReader for Arc<dyn LayoutReader + 'static> {
fn layout(&self) -> &LayoutData {
self.as_ref().layout()
}

fn children(&self) -> VortexResult<Vec<&Arc<dyn LayoutReader>>> {
self.as_ref().arc_children()
}
}

/// A trait for evaluating expressions against a [`LayoutReader`].
Expand Down
41 changes: 41 additions & 0 deletions vortex-layout/src/visitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use vortex_error::{VortexExpect, VortexResult};
use vortex_expr::traversal::{DynNode, Node, NodeVisitor, TraversalOrder};

use crate::LayoutReader;

impl DynNode for dyn LayoutReader {
fn arc_children(&self) -> VortexResult<Vec<&Arc<Self>>> {
self.children()
}
}

pub struct LayoutVisitor<'a, 'b> {
display: &'a mut Formatter<'b>,
}

impl NodeVisitor<'_> for LayoutVisitor<'_, '_> {
type NodeTy = Arc<dyn LayoutReader + 'static>;

fn visit_down(&mut self, node: &Self::NodeTy) -> VortexResult<TraversalOrder> {
node.layout().fmt(self.display)?;
self.display.write_str("\n")?;
Ok(TraversalOrder::Continue)
}
}

pub struct LayoutReaderDebug(pub Arc<dyn LayoutReader + 'static>);

impl Debug for LayoutReaderDebug {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
writeln!(f, "LayoutReader")?;
let mut vis = LayoutVisitor { display: f };
self.0
.accept(&mut vis)
.vortex_expect("Visitor should not fail");
Ok(())
}
}
Loading