Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.DataType;

/**
* A 'reducer' for output of user-defined functions.
Expand All @@ -39,4 +40,15 @@
@Evolving
public interface Reducer<I, O> {
  /**
   * Reduces the given input value to an output value.
   *
   * @param arg the input value to reduce
   * @return the reduced output value
   */
  O reduce(I arg);

  /**
   * Returns the {@link DataType data type} of values produced by this reducer.
   *
   * As a reducer doesn't know the result {@link DataType data type} of the reduced transform
   * function, for compatibility reasons it can return null to signal it doesn't change the type of
   * partition keys when the keys are reduced.
   *
   * @return a data type for values produced by this function, or null to signal the reducer
   *         doesn't change the type of the partition keys.
   */
  default DataType resultType() { return null; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,15 @@ case class KeyedPartitioning(
KeyedPartitioning.projectKeys(partitionKeys, expressionDataTypes, positions)

/**
 * Reduces this partitioning's partition keys by applying the given reducers, using the provided
 * reduced data types for key comparison.
 *
 * @param reducers one optional reducer per partition key position
 * @param reducedDataTypes the data types of the reduced keys, used to build the comparable
 *                         key wrappers
 * @return the distinct reduced partition keys
 */
def reduceKeys(
    reducers: Seq[Option[Reducer[_, _]]],
    reducedDataTypes: Seq[DataType]): Seq[InternalRowComparableWrapper] =
  KeyedPartitioning.reduceKeys(partitionKeys, expressionDataTypes, reducers, reducedDataTypes)
    .distinct

override def satisfies0(required: Distribution): Boolean = {
nonGroupedSatisfies(required) || groupedSatisfies(required)
Expand Down Expand Up @@ -581,14 +585,28 @@ object KeyedPartitioning {
}

/**
 * Reduces a sequence of data types by applying reducers to each position.
 *
 * @param dataTypes the original data types, one per partition key position
 * @param reducers one optional reducer per position; `None` keeps the original type
 * @return the reduced data types; a position keeps its original type when there is no reducer
 *         for it, or when the reducer's `resultType()` is null (signalling the reducer doesn't
 *         change the type)
 */
def reduceTypes(
    dataTypes: Seq[DataType],
    reducers: Seq[Option[Reducer[_, _]]]): Seq[DataType] = {
  dataTypes.zip(reducers).map {
    // No type ascription on the bound reducer: `Reducer[Any, Any]` would be an unchecked
    // (erased) type pattern, and `resultType()` doesn't depend on the type parameters.
    case (t, Some(reducer)) => Option(reducer.resultType()).getOrElse(t)
    case (t, _) => t
  }
}

/**
* Reduces a sequence of partition keys by applying reducers to each position and using the
* provided types for comparison.
*/
def reduceKeys(
keys: Seq[InternalRowComparableWrapper],
dataTypes: Seq[DataType],
reducers: Seq[Option[Reducer[_, _]]]): Seq[InternalRowComparableWrapper] = {
reducers: Seq[Option[Reducer[_, _]]],
reducedDataTypes: Seq[DataType]): Seq[InternalRowComparableWrapper] = {
val comparableKeyWrapperFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(dataTypes)
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(reducedDataTypes)
keys.map { key =>
val keyValues = key.row.toSeq(dataTypes)
val reducedKey = keyValues.zip(reducers).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_INCOMPATIBLE_TRANSFORM_TYPES =
buildConf("spark.sql.sources.v2.bucketing.allowIncompatibleTransformTypes.enabled")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we can set this configuration to false for some cases in the future, @peter-toth ? I'm a little confused about when it makes sense to disallow incompatible transform types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good question and I too was thinking about it. I feel we should not compare different logical types due to their different semantical meanings, but seemingly this is what we do currently in some cases, so we should probably keep the behavior for now. I think in a future Spark release we can change this config to make sure a comparison makes sense.

.doc("Whether to allow storage-partition join where the left and right partition " +
"transforms are reduced to differing logical types and in that case use the left reduced " +
"logical types for comparison. This config requires " +
s"${V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key} to be enabled.")
.version("4.2.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(true)

val V2_BUCKETING_PARTITION_FILTER_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.partition.filter.enabled")
.doc(s"Whether to filter partitions when running storage-partition join. " +
Expand Down Expand Up @@ -7692,6 +7703,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingAllowCompatibleTransforms: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)

/**
 * Whether storage-partition join may proceed when the two sides' partition transforms reduce
 * to differing logical types, using the left side's reduced types for comparison.
 * See `SQLConf.V2_BUCKETING_ALLOW_INCOMPATIBLE_TRANSFORM_TYPES`.
 */
def v2BucketingAllowIncompatibleTransformTypes: Boolean =
  getConf(SQLConf.V2_BUCKETING_ALLOW_INCOMPATIBLE_TRANSFORM_TYPES)

def v2BucketingAllowSorting: Boolean =
getConf(SQLConf.V2_BUCKETING_SORTING_ENABLED)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ abstract class InMemoryBaseTable(
case YearsTransform(ref) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (days: Int, DateType) =>
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)).toInt
case (micros: Long, TimestampType) =>
val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
Expand All @@ -225,7 +225,7 @@ abstract class InMemoryBaseTable(
case (days, DateType) =>
days
case (micros: Long, TimestampType) =>
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)).toInt
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,21 @@ case class GroupPartitionsExec(
)(keyedPartitioning.projectKeys)

// Reduce keys if reducers are specified
val reducedKeys = reducers.fold(projectedKeys)(
KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, _))
val (reducedDataTypes, reducedKeys) = reducers match {
case Some(reducers) =>
val reducedDataTypes = KeyedPartitioning.reduceTypes(projectedDataTypes, reducers)
val reducedKeys = KeyedPartitioning.reduceKeys(projectedKeys, projectedDataTypes, reducers,
reducedDataTypes)
(reducedDataTypes, reducedKeys)
case _ => (projectedDataTypes, projectedKeys)
}

val keyToPartitionIndices = reducedKeys.zipWithIndex.groupMap(_._1)(_._2)

if (expectedPartitionKeys.isDefined) {
alignToExpectedKeys(keyToPartitionIndices)
} else {
(groupAndSortByKeys(keyToPartitionIndices, projectedDataTypes), true)
(groupAndSortByKeys(keyToPartitionIndices, reducedDataTypes), true)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@ package org.apache.spark.sql.execution.exchange
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkException
import org.apache.spark.internal.{LogKeys}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.GroupPartitionsExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

/**
* Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
Expand Down Expand Up @@ -509,11 +512,27 @@ case class EnsureRequirements(
// in case of compatible but not identical partition expressions, we apply 'reduce'
// transforms to group one side's partitions as well as the common partition values
val leftReducers = leftSpec.reducers(rightSpec)
val leftReducedKeys =
leftReducers.fold(leftPartitioning.partitionKeys)(leftPartitioning.reduceKeys)
val rightReducers = rightSpec.reducers(leftSpec)
val rightReducedKeys =
rightReducers.fold(rightPartitioning.partitionKeys)(rightPartitioning.reduceKeys)
val leftReducedDataTypes = leftReducers.fold(leftPartitioning.expressionDataTypes)(
KeyedPartitioning.reduceTypes(leftPartitioning.expressionDataTypes, _))
val rightReducedDataTypes = rightReducers.fold(rightPartitioning.expressionDataTypes)(
KeyedPartitioning.reduceTypes(rightPartitioning.expressionDataTypes, _))
if (leftReducedDataTypes != rightReducedDataTypes && (
!conf.v2BucketingAllowIncompatibleTransformTypes ||
leftReducedDataTypes.map(PhysicalDataType(_)) !=
rightReducedDataTypes.map(PhysicalDataType(_)))) {
throw new SparkException("Storage-partition join partition transforms produced " +
s"incompatible reduced types, left: $leftReducedDataTypes, right: " +
s"$rightReducedDataTypes")
}
val commonDataTypes = leftReducedDataTypes
val leftReducedKeys = leftReducers.fold(leftPartitioning.partitionKeys)(
leftPartitioning.reduceKeys(_, commonDataTypes))
// As we use left side reduced types as common types for comparison, the right side
// partitions keys might need a new comparable wrapper (depending on the legacy flag)
val rightReducedKeys = rightReducers.fold(
rewrapKeys(rightPartitioning.partitionKeys, rightReducedDataTypes, commonDataTypes))(
rightPartitioning.reduceKeys(_, commonDataTypes))

// merge values on both sides
var mergedPartitionKeys =
Expand Down Expand Up @@ -628,10 +647,17 @@ case class EnsureRequirements(
}
}

val leftMergedPartitionKeys = mergedPartitionKeys
// As we used left side reduced types as common types for comparison, the merged partition
// keys that we push down to the right side might need a new comparable wrapper (depending
// on the legacy flag)
val rightMergedPartitionKeys =
rewrapKeyMap(mergedPartitionKeys, commonDataTypes, rightReducedDataTypes)

// Now we need to push-down the common partition information to the `GroupPartitionsExec`s.
newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, mergedPartitionKeys,
newLeft = applyGroupPartitions(left, leftSpec.joinKeyPositions, leftMergedPartitionKeys,
leftReducers, distributePartitions = applyPartialClustering && !replicateLeftSide)
newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, mergedPartitionKeys,
newRight = applyGroupPartitions(right, rightSpec.joinKeyPositions, rightMergedPartitionKeys,
rightReducers, distributePartitions = applyPartialClustering && !replicateRightSide)
}
}
Expand All @@ -656,6 +682,32 @@ case class EnsureRequirements(
}
}

/**
 * Re-wraps the given partition keys with a comparable wrapper built from `expectedDataType`
 * when it differs from `currentDataTypes`; otherwise returns the keys unchanged.
 *
 * @param keys the partition keys to re-wrap
 * @param currentDataTypes the data types the keys are currently wrapped with
 * @param expectedDataType the data types the returned key wrappers should compare with
 * @return the keys, comparable under `expectedDataType`
 */
private def rewrapKeys(
    keys: Seq[InternalRowComparableWrapper],
    currentDataTypes: Seq[DataType],
    expectedDataType: Seq[DataType]): Seq[InternalRowComparableWrapper] = {
  if (currentDataTypes != expectedDataType) {
    val comparableKeyWrapperFactory =
      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(expectedDataType)
    keys.map(key => comparableKeyWrapperFactory(key.row))
  } else {
    // Same types: the existing wrappers already compare correctly, avoid re-allocating.
    keys
  }
}

/**
 * Re-wraps the keys of a key-to-count mapping with a comparable wrapper built from
 * `expectedDataType` when it differs from `currentDataTypes`; otherwise returns the mapping
 * unchanged. The associated `Int` values are preserved as-is.
 *
 * @param keyMap pairs of partition key and its associated integer value
 * @param currentDataTypes the data types the keys are currently wrapped with
 * @param expectedDataType the data types the returned key wrappers should compare with
 * @return the mapping with keys comparable under `expectedDataType`
 */
private def rewrapKeyMap(
    keyMap: Seq[(InternalRowComparableWrapper, Int)],
    currentDataTypes: Seq[DataType],
    expectedDataType: Seq[DataType]): Seq[(InternalRowComparableWrapper, Int)] = {
  if (currentDataTypes != expectedDataType) {
    val comparableKeyWrapperFactory =
      InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(expectedDataType)
    keyMap.map { case (key, numParts) => (comparableKeyWrapperFactory(key.row), numParts) }
  } else {
    // Same types: the existing wrappers already compare correctly, avoid re-allocating.
    keyMap
  }
}

// Similar to `OptimizeSkewedJoin.canSplitRightSide`
private def canReplicateLeftSide(joinType: JoinType): Boolean = {
joinType == Inner || joinType == Cross || joinType == RightOuter
Expand Down
Loading