diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 211b5a05eb70c..0268e97616606 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -356,25 +356,30 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { + partitionValues: Seq[InternalRow] = Seq.empty, + isPartiallyClustered: Boolean = false) extends Partitioning { + // See SPARK-55848. Handle ClusteredDistribution explicitly (rather than + // delegating to super.satisfies0() first) so that the isPartiallyClustered + // guard below is always applied; all other distribution types fall + // through to the default super.satisfies0() behavior. override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) - } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } - - case _ => + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (isPartiallyClustered) { false - } + } else if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. 
+ c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + + case _ => + super.satisfies0(required) } } @@ -744,7 +749,10 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => - distribution.clustering.length == otherDistribution.clustering.length && + // SPARK-55848: partially-clustered partitioning is not compatible for SPJ + !partitioning.isPartiallyClustered && + !otherPartitioning.isPartiallyClustered && + distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 2a3a5cdeb82b8..2f262f5e13746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -121,7 +121,8 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues, + isPartiallyClustered = spjParams.applyPartialClustering) case p => p } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4546a5b1d2708..fd9fbe674ce27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -288,12 +288,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index f4317e632761c..3b973a4ca6c1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => + case KeyGroupedPartitioning(clustering, numPartitions, 
partitionValues, + isPartiallyClustered) => KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, - partitionValues) + partitionValues, isPartiallyClustered) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 7203cedf3ea7f..4819c210da15e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -293,6 +293,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { }) } + private def collectAllShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { + collect(plan) { + case s: ShuffleExchangeExec => s + } + } + private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = { collect(plan) { case s: BatchScanExec => s } } @@ -1259,4 +1265,151 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } assert(metrics("number of rows read") == "2") } + + test("SPARK-55848: dropDuplicates after SPJ with partial clustering") { + val items_partitions = Array(identity("id")) + 
createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // dropDuplicates on the join key after a partially-clustered SPJ must still + // produce the correct number of distinct ids. Before the fix, the + // partially-clustered partitioning was incorrectly treated as satisfying + // ClusteredDistribution, so EnsureRequirements did not insert an Exchange + // before the dedup, leading to duplicate rows. 
+ val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join dedup with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: Window dedup after SPJ with partial clustering") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered + // SPJ. The WINDOW operator requires ClusteredDistribution on i.id; with partial + // clustering the plan must insert a shuffle so that the window + // produces exactly one row per id. 
+ val df = sql( + s""" + |SELECT id, price FROM ( + | SELECT i.id, i.price, + | ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn + | FROM testcat.ns.$items i + | JOIN testcat.ns.$purchases p ON i.id = p.item_id + |) t WHERE rn = 1 + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join window with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: checkpointed partially-clustered join with dedup") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) 
{ + // Checkpoint the JOIN result (not the scan) so the plan behind the + // checkpoint carries partially-clustered KeyGroupedPartitioning. + // The dedup on top must still insert an Exchange because the + // isPartiallyClustered flag causes satisfies0()=false for + // ClusteredDistribution. + val joinedDf = spark.sql( + s"""SELECT i.id, i.name, i.price + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin) + val checkpointedDf = joinedDf.checkpoint() + val df = checkpointedDf.select("id").distinct() + + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the dedup after checkpointed " + + "partially-clustered join") + + val rddScans = collect(df.queryExecution.executedPlan) { + case r: RDDScanExec => r + } + assert(rddScans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "checkpoint (RDDScanExec) should have " + + "partially-clustered KeyGroupedPartitioning") + } + } + } }