diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala new file mode 100644 index 0000000000..3ad579c38c --- /dev/null +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.sql.UDF + +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal} +import org.apache.spark.sql.expressions.Aggregator +import org.locationtech.jts.geom.Geometry + +import scala.reflect.ClassTag + +abstract class AbstractCatalog { + + type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) + + val expressions: Seq[FunctionDescription] + + val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] + + protected def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = { + val classTag = implicitly[ClassTag[T]] + val constructor = classTag.runtimeClass.getConstructor(classOf[Seq[Expression]]) + val functionName = classTag.runtimeClass.getSimpleName + val functionIdentifier = FunctionIdentifier(functionName) + val expressionInfo = new ExpressionInfo( + classTag.runtimeClass.getCanonicalName, + functionIdentifier.database.orNull, + functionName) + + def functionBuilder(expressions: Seq[Expression]): T = { + val expr = constructor.newInstance(expressions).asInstanceOf[T] + expr match { + case e: ExpectsInputTypes => + val numParameters = e.inputTypes.size + val numArguments = expressions.size + if (numParameters == numArguments || numParameters == expr.children.size) expr + else { + val numUnspecifiedArgs = numParameters - numArguments + if (numUnspecifiedArgs > 0) { + if (numUnspecifiedArgs <= defaultArgs.size) { + val args = + expressions ++ defaultArgs.takeRight(numUnspecifiedArgs).map(Literal(_)) + constructor.newInstance(args).asInstanceOf[T] + } else { + throw new IllegalArgumentException(s"function $functionName takes at least " + + s"${numParameters - defaultArgs.size} argument(s), $numArguments argument(s) specified") + } + } else { + throw new IllegalArgumentException( + s"function $functionName takes at most " + + s"$numParameters argument(s), $numArguments argument(s) specified") + } + } + case _ => expr + } + } + + (functionIdentifier, expressionInfo, functionBuilder) + } +} diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 0bffa54baf..16c393cdbc 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -18,9 +18,6 @@ */ package org.apache.sedona.sql.UDF -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.sedona_sql.expressions.{ST_InterpolatePoint, _} import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect @@ -29,13 +26,10 @@ import org.locationtech.jts.geom.Geometry import org.locationtech.jts.operation.buffer.BufferParameters import scala.collection.mutable.ListBuffer -import scala.reflect.ClassTag -object Catalog { +object Catalog extends AbstractCatalog { - type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder) - - val expressions: Seq[FunctionDescription] = Seq( + override val expressions: Seq[FunctionDescription] = Seq( // Expression for vectors function[GeometryType](), function[ST_LabelPoint](), @@ -349,52 +343,10 @@ object Catalog { function[ST_BinaryDistanceBandColumn](), function[ST_WeightedDistanceBandColumn]()) - // Aggregate functions with Geometry as buffer - val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = + val aggregateExpressions: Seq[Aggregator[Geometry, _, _]] = Seq(new ST_Envelope_Aggr, new ST_Intersection_Aggr) // Aggregate functions with List as buffer val aggregateExpressions2: Seq[Aggregator[Geometry, ListBuffer[Geometry], Geometry]] = Seq(new ST_Union_Aggr()) - - private def function[T <: Expression: ClassTag](defaultArgs: Any*): FunctionDescription = { - val classTag = implicitly[ClassTag[T]] - val constructor = classTag.runtimeClass.getConstructor(classOf[Seq[Expression]]) - val functionName = classTag.runtimeClass.getSimpleName - val functionIdentifier = FunctionIdentifier(functionName) - val expressionInfo = new ExpressionInfo( - classTag.runtimeClass.getCanonicalName, - functionIdentifier.database.orNull, - functionName) - - def functionBuilder(expressions: Seq[Expression]): T = { - val expr = constructor.newInstance(expressions).asInstanceOf[T] - expr match { - case e: ExpectsInputTypes => - val numParameters = e.inputTypes.size - val numArguments = expressions.size - if (numParameters == numArguments) expr - else { - val numUnspecifiedArgs = numParameters - numArguments - if (numUnspecifiedArgs > 0) { - if (numUnspecifiedArgs <= defaultArgs.size) { - val args = - expressions ++ defaultArgs.takeRight(numUnspecifiedArgs).map(Literal(_)) - constructor.newInstance(args).asInstanceOf[T] - } else { - throw new IllegalArgumentException(s"function $functionName takes at least " + - s"${numParameters - defaultArgs.size} argument(s), $numArguments argument(s) specified") - } - } else { - throw new IllegalArgumentException( - s"function $functionName takes at most " + - s"$numParameters argument(s), $numArguments argument(s) specified") - } - } - case _ => expr - } - } - - (functionIdentifier, expressionInfo, functionBuilder) - } }