diff --git a/vortex-expr/src/transform/immediate_access.rs b/vortex-expr/src/transform/immediate_access.rs new file mode 100644 index 0000000000..aba6360ab8 --- /dev/null +++ b/vortex-expr/src/transform/immediate_access.rs @@ -0,0 +1,88 @@ +use itertools::Itertools; +use vortex_array::aliases::hash_map::HashMap; +use vortex_array::aliases::hash_set::HashSet; +use vortex_dtype::{FieldName, StructDType}; +use vortex_error::VortexResult; + +use crate::traversal::{Node, NodeVisitor, TraversalOrder}; +use crate::{ExprRef, GetItem, Identity, Select}; + +pub type FieldAccesses<'a> = HashMap<&'a ExprRef, HashSet>; + +pub fn immediate_scope_accesses<'a>( + expr: &'a ExprRef, + scope_dtype: &'a StructDType, +) -> VortexResult> { + ImmediateScopeAccessesAnalysis::<'a>::analyze(expr, scope_dtype) +} + +/// For all subexpressions in an expression, find the fields that are accessed directly from the +/// scope, but not any fields in those fields +/// e.g. scope = {a: {b: .., c: ..}, d: ..}, expr = ident().a.b + ident().d accesses {a,d} (not b). +struct ImmediateScopeAccessesAnalysis<'a> { + sub_expressions: FieldAccesses<'a>, + scope_dtype: &'a StructDType, +} + +impl<'a> ImmediateScopeAccessesAnalysis<'a> { + fn new(scope_dtype: &'a StructDType) -> Self { + Self { + sub_expressions: HashMap::new(), + scope_dtype, + } + } + + fn analyze(expr: &'a ExprRef, scope_dtype: &'a StructDType) -> VortexResult> { + let mut analysis = Self::new(scope_dtype); + expr.accept(&mut analysis)?; + Ok(analysis.sub_expressions) + } +} + +// This is a very naive, but simple analysis to find the fields that are accessed directly on an +// identity node. This is combined to provide an over-approximation of the fields that are accessed +// by an expression. +impl<'a> NodeVisitor<'a> for ImmediateScopeAccessesAnalysis<'a> { + type NodeTy = ExprRef; + + fn visit_down(&mut self, node: &'a Self::NodeTy) -> VortexResult { + assert!( + !node.as_any().is::() { - assert!(matches!(select.fields(), SelectField::Include(_))); - if select.child().as_any().is::() { - self.sub_expressions.insert( - node, - HashSet::from_iter(select.fields().fields().iter().cloned()), - ); - } - return Ok(FoldDown::SkipChildren(())); - } else if node.as_any().is::() { - let st_dtype = &self.scope_dtype; - self.sub_expressions - .insert(node, st_dtype.names().iter().cloned().collect()); - } - - Ok(FoldDown::Continue(())) - } - - fn visit_up( - &mut self, - node: &'a ExprRef, - _context: (), - _children: Vec<()>, - ) -> VortexResult> { - let accesses = node - .children() - .iter() - .filter_map(|c| self.sub_expressions.get(c).cloned()) - .collect_vec(); - - let node_accesses = self.sub_expressions.entry(node).or_default(); - accesses - .into_iter() - .for_each(|fields| node_accesses.extend(fields.iter().cloned())); - - Ok(FoldUp::Continue(())) - } -} - #[derive(Debug)] struct StructFieldExpressionSplitter<'a> { sub_expressions: HashMap>, - accesses: FieldAccesses<'a>, + accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType, } impl<'a> StructFieldExpressionSplitter<'a> { - fn new(accesses: FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self { + fn new(accesses: &'a FieldAccesses<'a>, scope_dtype: &'a StructDType) -> Self { Self { sub_expressions: HashMap::new(), accesses, @@ -168,16 +82,9 @@ impl<'a> StructFieldExpressionSplitter<'a> { _ => vortex_bail!("Expected a struct dtype, got {:?}", dtype), }; - let mut expr_top_level_ref = ImmediateIdentityAccessesAnalysis::new(scope_dtype); - expr.accept_with_context(&mut expr_top_level_ref, ())?; - - let expression_accesses = expr_top_level_ref - .sub_expressions - .get(&expr) - .map(|ac| ac.len()); + let field_accesses = immediate_scope_accesses(&expr, scope_dtype)?; - let mut splitter = - StructFieldExpressionSplitter::new(expr_top_level_ref.sub_expressions, scope_dtype); + let mut splitter = StructFieldExpressionSplitter::new(&field_accesses, scope_dtype); let split = expr.clone().transform_with_context(&mut splitter, ())?; @@ -212,11 +119,12 @@ impl<'a> StructFieldExpressionSplitter<'a> { }) .try_collect()?; + let expression_access_counts = field_accesses.get(&expr).map(|ac| ac.len()); // Ensure that there are not more accesses than partitions, we missed something - assert!(expression_accesses.unwrap_or(0) <= partitions.len()); + assert!(expression_access_counts.unwrap_or(0) <= partitions.len()); // Ensure that there are as many partitions as there are accesses/fields in the scope, // this will affect performance, not correctness. - debug_assert_eq!(expression_accesses.unwrap_or(0), partitions.len()); + debug_assert_eq!(expression_access_counts.unwrap_or(0), partitions.len()); let split = split .result()