From 1b3f1ac32e1749bf0d73a86d18b005be1339241a Mon Sep 17 00:00:00 2001 From: Naveen Kumar Puppala Date: Mon, 16 Mar 2026 23:14:36 -0700 Subject: [PATCH] [SPARK-55848][SQL][3.5] Fix incorrect dedup results with SPJ partial clustering When SPJ partial clustering splits a partition across multiple tasks, post-join dedup operators (dropDuplicates, Window row_number) produce incorrect results because KeyGroupedPartitioning.satisfies0() incorrectly reports satisfaction of ClusteredDistribution. This fix adds an isPartiallyClustered flag to KeyGroupedPartitioning and restructures satisfies0() to check ClusteredDistribution first, returning false when partially clustered. EnsureRequirements then inserts the necessary Exchange. Plain SPJ joins without dedup are unaffected. Closes #54378 --- .../plans/physical/partitioning.scala | 42 +++-- .../datasources/v2/BatchScanExec.scala | 3 +- .../exchange/EnsureRequirements.scala | 4 +- .../DistributionAndOrderingSuiteBase.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 155 +++++++++++++++++- 5 files changed, 186 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 211b5a05eb70c..0268e97616606 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -356,25 +356,30 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { + partitionValues: Seq[InternalRow] = Seq.empty, + isPartiallyClustered: Boolean = false) extends Partitioning { + // See SPARK-55848. 
We handle ClusteredDistribution explicitly here so that the + // isPartiallyClustered guard is always consulted; all other distribution + // types fall through to super.satisfies0(). override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) - } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } - - case _ => + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (isPartiallyClustered) { false - } + } else if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } + + case _ => + super.satisfies0(required) } } @@ -744,7 +749,10 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. 
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => - distribution.clustering.length == otherDistribution.clustering.length && + // SPARK-55848: partially-clustered partitioning is not compatible for SPJ + !partitioning.isPartiallyClustered && + !otherPartitioning.isPartiallyClustered && + distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 2a3a5cdeb82b8..2f262f5e13746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -121,7 +121,8 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues, + isPartiallyClustered = spjParams.applyPartialClustering) case p => p } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4546a5b1d2708..fd9fbe674ce27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -288,12 +288,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) 
.orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index f4317e632761c..3b973a4ca6c1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => + case KeyGroupedPartitioning(clustering, numPartitions, partitionValues, + isPartiallyClustered) => KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, - partitionValues) + partitionValues, isPartiallyClustered) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 7203cedf3ea7f..4819c210da15e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -293,6 +293,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { }) } + private def collectAllShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { + collect(plan) { + case s: ShuffleExchangeExec => s + } + } + private def collectScans(plan: SparkPlan): Seq[BatchScanExec] = { collect(plan) { case s: BatchScanExec => s } } @@ -1259,4 +1265,151 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } assert(metrics("number of rows read") == "2") } + + test("SPARK-55848: dropDuplicates after SPJ with partial clustering") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + 
sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // dropDuplicates on the join key after a partially-clustered SPJ must still + // produce the correct number of distinct ids. Before the fix, the + // partially-clustered partitioning was incorrectly treated as satisfying + // ClusteredDistribution, so EnsureRequirements did not insert an Exchange + // before the dedup, leading to duplicate rows. + val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join dedup with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: Window dedup after SPJ with partial clustering") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, 
purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered + // SPJ. The WINDOW operator requires ClusteredDistribution on i.id; with partial + // clustering the plan must insert a shuffle so that the window + // produces exactly one row per id. + val df = sql( + s""" + |SELECT id, price FROM ( + | SELECT i.id, i.price, + | ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn + | FROM testcat.ns.$items i + | JOIN testcat.ns.$purchases p ON i.id = p.item_id + |) t WHERE rn = 1 + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join window with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: checkpointed partially-clustered join with dedup") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, 
cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Checkpoint the JOIN result (not the scan) so the plan behind the + // checkpoint carries partially-clustered KeyGroupedPartitioning. + // The dedup on top must still insert an Exchange because the + // isPartiallyClustered flag causes satisfies0()=false for + // ClusteredDistribution. + val joinedDf = spark.sql( + s"""SELECT i.id, i.name, i.price + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin) + val checkpointedDf = joinedDf.checkpoint() + val df = checkpointedDf.select("id").distinct() + + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the dedup after checkpointed " + + "partially-clustered join") + + val rddScans = collect(df.queryExecution.executedPlan) { + case r: RDDScanExec => r + } + assert(rddScans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "checkpoint (RDDScanExec) should have " + + "partially-clustered KeyGroupedPartitioning") + } + } + } }