Type Evolution in INSERT
johanl-db committed Mar 22, 2024
1 parent 5f6d66a commit 5efbef6
Showing 3 changed files with 517 additions and 24 deletions.
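
The commit carries no description, so here is a hedged sketch of the end-to-end behavior it enables. The diff below only adds the analysis-side plumbing; the table property and SQL config names used here (delta.enableTypeWidening and spark.databricks.delta.schema.autoMerge.enabled) are assumptions that do not appear in this diff.

// Hedged sketch, not part of the commit. Assumes the type widening table feature
// and Delta schema auto-merge are enabled; the property and config names are assumptions.
spark.conf.set("spark.databricks.delta.schema.autoMerge.enabled", "true")
spark.sql("""CREATE TABLE t (a SMALLINT) USING delta
             TBLPROPERTIES ('delta.enableTypeWidening' = 'true')""")
// Inserting INT values into the SMALLINT column: instead of casting the values
// down to SMALLINT, analysis keeps the wider INT type and schema evolution
// widens column a to INT.
spark.sql("INSERT INTO t VALUES (40000)")
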
80 changes: 59 additions & 21 deletions spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -66,7 +66,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf
- import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
+ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -81,8 +81,8 @@ class DeltaAnalysis(session: SparkSession)
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
// INSERT INTO by ordinal and df.insertInto()
case a @ AppendDelta(r, d) if !a.isByName &&
- needsSchemaAdjustmentByOrdinal(d.name(), a.query, r.schema) =>
-   val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d.name())
+ needsSchemaAdjustmentByOrdinal(d, a.query, r.schema) =>
+   val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d)
if (projection != a.query) {
a.copy(query = projection)
} else {
@@ -208,8 +208,8 @@ class DeltaAnalysis(session: SparkSession)

// INSERT OVERWRITE by ordinal and df.insertInto()
case o @ OverwriteDelta(r, d) if !o.isByName &&
- needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema) =>
-   val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
+ needsSchemaAdjustmentByOrdinal(d, o.query, r.schema) =>
+   val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d)
if (projection != o.query) {
val aliases = AttributeMap(o.query.output.zip(projection.output).collect {
case (l: AttributeReference, r: AttributeReference) if !l.sameRef(r) => (l, r)
@@ -245,9 +245,9 @@ class DeltaAnalysis(session: SparkSession)
case o @ DynamicPartitionOverwriteDelta(r, d) if o.resolved
=>
val adjustedQuery = if (!o.isByName &&
- needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema)) {
+ needsSchemaAdjustmentByOrdinal(d, o.query, r.schema)) {
// INSERT OVERWRITE by ordinal and df.insertInto()
- resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
+ resolveQueryColumnsByOrdinal(o.query, r.output, d)
} else if (o.isByName && o.origin.sqlText.nonEmpty &&
needsSchemaAdjustmentByName(o.query, r.output, d)) {
// INSERT OVERWRITE by name
@@ -850,12 +850,14 @@ class DeltaAnalysis(session: SparkSession)
* type column/field.
*/
private def resolveQueryColumnsByOrdinal(
- query: LogicalPlan, targetAttrs: Seq[Attribute], tblName: String): LogicalPlan = {
+ query: LogicalPlan, targetAttrs: Seq[Attribute], deltaTable: DeltaTableV2): LogicalPlan = {
// always add a Cast. it will be removed in the optimizer if it is unnecessary.
val project = query.output.zipWithIndex.map { case (attr, i) =>
if (i < targetAttrs.length) {
val targetAttr = targetAttrs(i)
- addCastToColumn(attr, targetAttr, tblName)
+ addCastToColumn(attr, targetAttr, deltaTable.name(),
+   allowTypeWidening = allowTypeWidening(deltaTable)
+ )
} else {
attr
}
@@ -890,47 +892,69 @@ class DeltaAnalysis(session: SparkSession)
.getOrElse {
throw DeltaErrors.missingColumn(attr, targetAttrs)
}
- addCastToColumn(attr, targetAttr, deltaTable.name())
+ addCastToColumn(attr, targetAttr, deltaTable.name(),
+   allowTypeWidening = allowTypeWidening(deltaTable)
+ )
}
Project(project, query)
}

private def addCastToColumn(
attr: Attribute,
targetAttr: Attribute,
- tblName: String): NamedExpression = {
+ tblName: String,
+ allowTypeWidening: Boolean): NamedExpression = {
val expr = (attr.dataType, targetAttr.dataType) match {
case (s, t) if s == t =>
attr
case (s: StructType, t: StructType) if s != t =>
- addCastsToStructs(tblName, attr, s, t)
+ addCastsToStructs(tblName, attr, s, t, allowTypeWidening)
case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
if s != t && sNull == tNull =>
- addCastsToArrayStructs(tblName, attr, s, t, sNull)
+ addCastsToArrayStructs(tblName, attr, s, t, sNull, allowTypeWidening)
+ case (s: AtomicType, t: AtomicType)
+   if allowTypeWidening && TypeWidening.isTypeChangeSupportedForSchemaEvolution(t, s) =>
+   // Keep the type from the query, the target schema will be updated to widen the existing
+   // type to match it.
+   attr
case _ =>
getCastFunction(attr, targetAttr.dataType, targetAttr.name)
}
Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
}

+ /**
+  * Whether inserting values that have a wider type than the table has is allowed. In that case,
+  * values are not downcasted to the current table type and the table schema is updated instead to
+  * use the wider type.
+  */
+ private def allowTypeWidening(deltaTable: DeltaTableV2): Boolean = {
+   val options = new DeltaOptions(Map.empty[String, String], conf)
+   options.canMergeSchema && TypeWidening.isEnabled(
+     deltaTable.initialSnapshot.protocol,
+     deltaTable.initialSnapshot.metadata
+   )
+ }

/**
* With Delta, we ACCEPT_ANY_SCHEMA, meaning that Spark doesn't automatically adjust the schema
* of INSERT INTO. This allows us to perform better schema enforcement/evolution. Since Spark
* skips this step, we see if we need to perform any schema adjustment here.
*/
private def needsSchemaAdjustmentByOrdinal(
- tableName: String,
+ deltaTable: DeltaTableV2,
query: LogicalPlan,
schema: StructType): Boolean = {
val output = query.output
if (output.length < schema.length) {
- throw DeltaErrors.notEnoughColumnsInInsert(tableName, output.length, schema.length)
+ throw DeltaErrors.notEnoughColumnsInInsert(deltaTable.name(), output.length, schema.length)
}
// Now we should try our best to match everything that already exists, and leave the rest
// for schema evolution to WriteIntoDelta
val existingSchemaOutput = output.take(schema.length)
existingSchemaOutput.map(_.name) != schema.map(_.name) ||
- !SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType)
+ !SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType,
+   allowTypeWidening = allowTypeWidening(deltaTable))
}

/**
@@ -984,7 +1008,10 @@ class DeltaAnalysis(session: SparkSession)
}
val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
!SchemaUtils.isReadCompatible(
- specifiedTargetAttrs.toStructType.asNullable, query.output.toStructType)
+ specifiedTargetAttrs.toStructType.asNullable,
+ query.output.toStructType,
+ allowTypeWidening = allowTypeWidening(deltaTable)
+ )
}

// Get cast operation for the level of strictness in the schema a user asked for
@@ -1014,7 +1041,8 @@ class DeltaAnalysis(session: SparkSession)
tableName: String,
parent: NamedExpression,
source: StructType,
- target: StructType): NamedExpression = {
+ target: StructType,
+ allowTypeWidening: Boolean): NamedExpression = {
if (source.length < target.length) {
throw DeltaErrors.notEnoughColumnsInInsert(
tableName, source.length, target.length, Some(parent.qualifiedName))
@@ -1025,12 +1053,20 @@ class DeltaAnalysis(session: SparkSession)
case t: StructType =>
val subField = Alias(GetStructField(parent, i, Option(name)), target(i).name)(
explicitMetadata = Option(metadata))
- addCastsToStructs(tableName, subField, nested, t)
+ addCastsToStructs(tableName, subField, nested, t, allowTypeWidening)
case o =>
val field = parent.qualifiedName + "." + name
val targetName = parent.qualifiedName + "." + target(i).name
throw DeltaErrors.cannotInsertIntoColumn(tableName, field, targetName, o.simpleString)
}

+ case (StructField(name, dt: AtomicType, _, _), i) if i < target.length && allowTypeWidening &&
+   TypeWidening.isTypeChangeSupportedForSchemaEvolution(
+     target(i).dataType.asInstanceOf[AtomicType], dt) =>
+   val targetAttr = target(i)
+   Alias(
+     GetStructField(parent, i, Option(name)),
+     targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
case (other, i) if i < target.length =>
val targetAttr = target(i)
Alias(
@@ -1054,9 +1090,11 @@ class DeltaAnalysis(session: SparkSession)
parent: NamedExpression,
source: StructType,
target: StructType,
- sourceNullable: Boolean): Expression = {
+ sourceNullable: Boolean,
+ allowTypeWidening: Boolean): Expression = {
val structConverter: (Expression, Expression) => Expression = (_, i) =>
- addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
+ addCastsToStructs(
+   tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target, allowTypeWidening)
val transformLambdaFunc = {
val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
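
As a hedged illustration of what the new addCastToColumn path changes for the DataFrame insertInto route (same assumed table and configuration as the sketch near the top; SMALLINT to INT is assumed to be a supported widening):

// Hedged sketch, not part of the commit. Table t has a single SMALLINT column a,
// with type widening and schema auto-merge enabled (assumed, as above).
val df = spark.range(1).selectExpr("CAST(40000 AS INT) AS a")
// Before this change the analyzer added a cast down to SMALLINT; with
// allowTypeWidening = true the new AtomicType case keeps the INT attribute and
// leaves widening of column a to the write path (WriteIntoDelta).
df.write.insertInto("t")
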
@@ -352,7 +352,8 @@ def normalizeColumnNamesInDataType(
* new schema of a Delta table can be used with a previously analyzed LogicalPlan. Our
* rules are to return false if:
* - Dropping any column that was present in the existing schema, if not allowMissingColumns
- * - Any change of datatype
+ * - Any change of datatype, if not allowTypeWidening. Any non-widening change of datatype
+ *   otherwise.
* - Change of partition columns. Although analyzed LogicalPlan is not changed,
* physical structure of data is changed and thus is considered not read compatible.
* - If `forbidTightenNullability` = true:
@@ -373,6 +374,7 @@ def normalizeColumnNamesInDataType(
readSchema: StructType,
forbidTightenNullability: Boolean = false,
allowMissingColumns: Boolean = false,
+ allowTypeWidening: Boolean = false,
newPartitionColumns: Seq[String] = Seq.empty,
oldPartitionColumns: Seq[String] = Seq.empty): Boolean = {

@@ -387,7 +389,7 @@ def normalizeColumnNamesInDataType(
def isDatatypeReadCompatible(existing: DataType, newtype: DataType): Boolean = {
(existing, newtype) match {
case (e: StructType, n: StructType) =>
- isReadCompatible(e, n, forbidTightenNullability)
+ isReadCompatible(e, n, forbidTightenNullability, allowTypeWidening = allowTypeWidening)
case (e: ArrayType, n: ArrayType) =>
// if existing elements are non-nullable, so should be the new element
isNullabilityCompatible(e.containsNull, n.containsNull) &&
Expand All @@ -397,6 +399,8 @@ def normalizeColumnNamesInDataType(
isNullabilityCompatible(e.valueContainsNull, n.valueContainsNull) &&
isDatatypeReadCompatible(e.keyType, n.keyType) &&
isDatatypeReadCompatible(e.valueType, n.valueType)
+ case (e: AtomicType, n: AtomicType) if allowTypeWidening =>
+   TypeWidening.isTypeChangeSupportedForSchemaEvolution(e, n)
case (a, b) => a == b
}
}
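
A hedged sketch of what the new flag changes for callers of isReadCompatible, assuming a ShortType to IntegerType change is among the widenings accepted by TypeWidening.isTypeChangeSupportedForSchemaEvolution:

import org.apache.spark.sql.types._

val existing = new StructType().add("a", ShortType)
val widened  = new StructType().add("a", IntegerType)

// Default: any datatype change falls through to the equality check and is rejected.
SchemaUtils.isReadCompatible(existing, widened)                           // false
// With the new flag, a supported widening is accepted; narrowing changes
// (e.g. IntegerType -> ShortType) are still rejected.
SchemaUtils.isReadCompatible(existing, widened, allowTypeWidening = true) // true, assuming Short -> Int is supported
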