Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -356,25 +356,30 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
partitionValues: Seq[InternalRow] = Seq.empty,
isPartiallyClustered: Boolean = false) extends Partitioning {

// See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to
// super.satisfies0(), because the default satisfies0() also matches
// ClusteredDistribution and returns true, which would short-circuit the
// isPartiallyClustered guard.
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as #54851 (comment).

// `ClusteredDistribution`.
c.areAllClusterKeysMatched(expressions)
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}

case _ =>
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
if (isPartiallyClustered) {
false
}
} else if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
// `ClusteredDistribution`.
c.areAllClusterKeysMatched(expressions)
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}

case _ =>
super.satisfies0(required)
}
}

Expand Down Expand Up @@ -744,7 +749,10 @@ case class KeyGroupedShuffleSpec(
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
distribution.clustering.length == otherDistribution.clustering.length &&
// SPARK-55848: partially-clustered partitioning is not compatible with storage-partitioned join (SPJ)
!partitioning.isPartiallyClustered &&
!otherPartitioning.isPartiallyClustered &&
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ case class BatchScanExec(
val newPartValues = spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
}
k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues,
isPartiallyClustered = spjParams.applyPartialClustering)
case p => p
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase
plan: QueryPlan[T]): Partitioning = partitioning match {
case HashPartitioning(exprs, numPartitions) =>
HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
case KeyGroupedPartitioning(clustering, numPartitions, partitionValues,
isPartiallyClustered) =>
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions,
partitionValues)
partitionValues, isPartiallyClustered)
case PartitioningCollection(partitionings) =>
PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
case RangePartitioning(ordering, numPartitions) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
Expand Down Expand Up @@ -293,6 +293,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
})
}

/** Returns every [[ShuffleExchangeExec]] node found anywhere in `plan`. */
private def collectAllShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] =
  collect(plan) { case shuffle: ShuffleExchangeExec => shuffle }

/** Returns every [[BatchScanExec]] node found anywhere in `plan`. */
private def collectScans(plan: SparkPlan): Seq[BatchScanExec] =
  collect(plan) {
    case scan: BatchScanExec => scan
  }
Expand Down Expand Up @@ -1259,4 +1265,151 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
assert(metrics("number of rows read") == "2")
}

test("SPARK-55848: dropDuplicates after SPJ with partial clustering") {
  // Partition both sides on the join key so the join is eligible for SPJ.
  createTable(items, items_schema, Array(identity("id")))
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
    "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
    "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

  createTable(purchases, purchases_schema, Array(identity("item_id")))
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    "(1, 42.0, cast('2020-01-01' as timestamp)), " +
    "(1, 50.0, cast('2020-01-02' as timestamp)), " +
    "(2, 11.0, cast('2020-01-01' as timestamp)), " +
    "(3, 19.5, cast('2020-02-01' as timestamp))")

  withSQLConf(
    SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
    SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
    // dropDuplicates on the join key after a partially-clustered SPJ must still
    // produce the correct number of distinct ids. Before the fix, the
    // partially-clustered partitioning was incorrectly treated as satisfying
    // ClusteredDistribution, so EnsureRequirements did not insert an Exchange
    // before the dedup, leading to duplicate rows.
    val dedupDf = sql(
      s"""
         |SELECT DISTINCT i.id
         |FROM testcat.ns.$items i
         |JOIN testcat.ns.$purchases p ON i.id = p.item_id
         |""".stripMargin)
    checkAnswer(dedupDf, Seq(Row(1), Row(2), Row(3)))

    val executedPlan = dedupDf.queryExecution.executedPlan
    assert(collectAllShuffles(executedPlan).nonEmpty,
      "should contain a shuffle for the post-join dedup with partial clustering")

    // The scan itself must advertise the partially-clustered flag, otherwise
    // this test would not be exercising the SPARK-55848 code path at all.
    val hasPartiallyClusteredScan = collectScans(executedPlan).exists { scan =>
      scan.outputPartitioning match {
        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
        case _ => false
      }
    }
    assert(hasPartiallyClusteredScan,
      "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
  }
}

test("SPARK-55848: Window dedup after SPJ with partial clustering") {
  // Partition both sides on the join key so the join is eligible for SPJ.
  createTable(items, items_schema, Array(identity("id")))
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
    "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
    "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

  createTable(purchases, purchases_schema, Array(identity("item_id")))
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    "(1, 42.0, cast('2020-01-01' as timestamp)), " +
    "(1, 50.0, cast('2020-01-02' as timestamp)), " +
    "(2, 11.0, cast('2020-01-01' as timestamp)), " +
    "(3, 19.5, cast('2020-02-01' as timestamp))")

  withSQLConf(
    SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
    SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
    // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered
    // SPJ. The WINDOW operator requires ClusteredDistribution on i.id; with partial
    // clustering the plan must insert a shuffle so that the window
    // produces exactly one row per id.
    val windowDf = sql(
      s"""
         |SELECT id, price FROM (
         |  SELECT i.id, i.price,
         |    ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn
         |  FROM testcat.ns.$items i
         |  JOIN testcat.ns.$purchases p ON i.id = p.item_id
         |) t WHERE rn = 1
         |""".stripMargin)
    checkAnswer(windowDf, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f)))

    val executedPlan = windowDf.queryExecution.executedPlan
    assert(collectAllShuffles(executedPlan).nonEmpty,
      "should contain a shuffle for the post-join window with partial clustering")

    // Sanity check: the scan must carry the partially-clustered flag, otherwise
    // the test would not be exercising the SPARK-55848 code path.
    val hasPartiallyClusteredScan = collectScans(executedPlan).exists { scan =>
      scan.outputPartitioning match {
        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
        case _ => false
      }
    }
    assert(hasPartiallyClusteredScan,
      "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
  }
}

test("SPARK-55848: checkpointed partially-clustered join with dedup") {
  withTempDir { dir =>
    spark.sparkContext.setCheckpointDir(dir.getPath)

    // Partition both sides on the join key so the join is eligible for SPJ.
    createTable(items, items_schema, Array(identity("id")))
    sql(s"INSERT INTO testcat.ns.$items VALUES " +
      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
      "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

    createTable(purchases, purchases_schema, Array(identity("item_id")))
    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
      "(1, 50.0, cast('2020-01-02' as timestamp)), " +
      "(2, 11.0, cast('2020-01-01' as timestamp)), " +
      "(3, 19.5, cast('2020-02-01' as timestamp))")

    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
      // Checkpoint the JOIN result (not the scan) so the plan behind the
      // checkpoint carries partially-clustered KeyGroupedPartitioning.
      // The dedup on top must still insert an Exchange because the
      // isPartiallyClustered flag causes satisfies0()=false for
      // ClusteredDistribution.
      val joinedDf = spark.sql(
        s"""SELECT i.id, i.name, i.price
           |FROM testcat.ns.$items i
           |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin)
      val dedupDf = joinedDf.checkpoint().select("id").distinct()

      checkAnswer(dedupDf, Seq(Row(1), Row(2), Row(3)))

      val executedPlan = dedupDf.queryExecution.executedPlan
      assert(collectAllShuffles(executedPlan).nonEmpty,
        "should contain a shuffle for the dedup after checkpointed partially-clustered join")

      // The checkpoint materializes as an RDDScanExec; it must still report the
      // partially-clustered flag so the dedup above it cannot reuse the join's
      // partitioning.
      val checkpointScans = collect(executedPlan) {
        case r: RDDScanExec => r
      }
      val hasPartiallyClusteredCheckpoint = checkpointScans.exists { scan =>
        scan.outputPartitioning match {
          case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
          case _ => false
        }
      }
      assert(hasPartiallyClusteredCheckpoint,
        "checkpoint (RDDScanExec) should have partially-clustered KeyGroupedPartitioning")
    }
  }
}
}