Skip to content

Commit

Permalink
[SEDONA-708] Separate Catalog implementation into AbstractCatalog and…
Browse files Browse the repository at this point in the history
… Catalog (#1798)
  • Loading branch information
james-willis authored Feb 11, 2025
1 parent 607c383 commit 76093b5
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql.UDF

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.locationtech.jts.geom.Geometry

import scala.reflect.ClassTag

/**
 * Base class for function catalogs.
 *
 * Concrete catalogs supply the lists of function descriptions and aggregators; this class
 * supplies the reflective [[function]] helper used to build each description.
 */
abstract class AbstractCatalog {

  /** Identifier, descriptive metadata and builder for a single expression-backed function. */
  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)

  /** Function descriptions provided by the concrete catalog. */
  val expressions: Seq[FunctionDescription]

  /** Aggregators over `Geometry` input provided by the concrete catalog. */
  val aggregateExpressions: Seq[Aggregator[Geometry, _, _]]

  /**
   * Builds the [[FunctionDescription]] for the expression class `T`.
   *
   * The function is named after the simple class name of `T`. `T` must expose a public
   * constructor taking a single `Seq[Expression]`; that constructor is looked up as soon as
   * this method is called, not when the builder is first used.
   *
   * The returned builder instantiates `T` with the supplied arguments. When the instance is an
   * `ExpectsInputTypes`, the number of supplied arguments is checked against the number of
   * declared input types:
   *   - if they match, or if the declared count equals the instance's number of children, the
   *     instance is returned as is;
   *   - if arguments are missing and enough defaults exist, the trailing `defaultArgs` are
   *     appended as literals and `T` is instantiated again;
   *   - otherwise an `IllegalArgumentException` is thrown.
   *
   * @param defaultArgs
   *   default values for trailing parameters; the last `k` entries fill the last `k` missing
   *   arguments
   * @tparam T
   *   the expression class to describe
   * @return
   *   the identifier, expression info and builder for `T`
   */
  protected def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
    val exprClass = implicitly[ClassTag[T]].runtimeClass
    val ctor = exprClass.getConstructor(classOf[Seq[Expression]])
    val name = exprClass.getSimpleName
    val identifier = FunctionIdentifier(name)
    val info = new ExpressionInfo(exprClass.getCanonicalName, identifier.database.orNull, name)

    // Single place where the reflective constructor call and the cast to T happen.
    def instantiate(args: Seq[Expression]): T = ctor.newInstance(args).asInstanceOf[T]

    def build(supplied: Seq[Expression]): T = {
      val candidate = instantiate(supplied)
      candidate match {
        case typed: ExpectsInputTypes =>
          val declaredCount = typed.inputTypes.size
          val suppliedCount = supplied.size
          val missingCount = declaredCount - suppliedCount
          if (declaredCount == suppliedCount || declaredCount == candidate.children.size) {
            candidate
          } else if (missingCount <= 0) {
            // More arguments were passed than the expression declares input types for.
            throw new IllegalArgumentException(
              s"function $name takes at most " +
                s"$declaredCount argument(s), $suppliedCount argument(s) specified")
          } else if (missingCount > defaultArgs.size) {
            // Not enough defaults to cover the arguments the caller left out.
            throw new IllegalArgumentException(
              s"function $name takes at least " +
                s"${declaredCount - defaultArgs.size} argument(s), $suppliedCount argument(s) specified")
          } else {
            // Pad with the trailing defaults, wrapped as literals, and instantiate again.
            instantiate(supplied ++ defaultArgs.takeRight(missingCount).map(Literal(_)))
          }
        case _ => candidate
      }
    }

    (identifier, info, build)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
*/
package org.apache.sedona.sql.UDF

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _}
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
Expand All @@ -29,13 +26,10 @@ import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

object Catalog {
object Catalog extends AbstractCatalog {

type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)

val expressions: Seq[FunctionDescription] = Seq(
override val expressions: Seq[FunctionDescription] = Seq(
// Expression for vectors
function[GeometryType](),
function[ST_LabelPoint](),
Expand Down Expand Up @@ -349,52 +343,10 @@ object Catalog {
function[ST_BinaryDistanceBandColumn](),
function[ST_WeightedDistanceBandColumn]())

// Aggregate functions with Geometry as buffer
val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] =
val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] =
Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr)

// Aggregate functions with List as buffer
val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] =
Seq(new ST_Union_Aggr())

private def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = {
val classTag = implicitly[ClassTag[T]]
val constructor = classTag.runtimeClass.getConstructor(classOf[Seq[Expression]])
val functionName = classTag.runtimeClass.getSimpleName
val functionIdentifier = FunctionIdentifier(functionName)
val expressionInfo = new ExpressionInfo(
classTag.runtimeClass.getCanonicalName,
functionIdentifier.database.orNull,
functionName)

def functionBuilder(expressions: Seq[Expression]): T = {
val expr = constructor.newInstance(expressions).asInstanceOf[T]
expr match {
case e: ExpectsInputTypes =>
val numParameters = e.inputTypes.size
val numArguments = expressions.size
if (numParameters == numArguments) expr
else {
val numUnspecifiedArgs = numParameters - numArguments
if (numUnspecifiedArgs > 0) {
if (numUnspecifiedArgs <= defaultArgs.size) {
val args =
expressions ++ defaultArgs.takeRight(numUnspecifiedArgs).map(Literal(_))
constructor.newInstance(args).asInstanceOf[T]
} else {
throw new IllegalArgumentException(s"function $functionName takes at least " +
s"${numParameters - defaultArgs.size} argument(s), $numArguments argument(s) specified")
}
} else {
throw new IllegalArgumentException(
s"function $functionName takes at most " +
s"$numParameters argument(s), $numArguments argument(s) specified")
}
}
case _ => expr
}
}

(functionIdentifier, expressionInfo, functionBuilder)
}
}

0 comments on commit 76093b5

Please sign in to comment.