From f5d7ed0a08f8f38427134a0ba03696c5456af07e Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Wed, 18 Mar 2026 09:35:32 +0000 Subject: [PATCH 1/3] [SPARK-XXXX][SQL] Add new single-pass resolver infrastructure from Databricks Runtime Port additional single-pass resolver components from Databricks Runtime to OSS Spark, expanding the resolver package with new operator and expression resolvers. This includes support for pivot/unpivot, higher-order functions, lambda functions, table-valued functions, extract values, grouping analytics, repartition-by-expression, name-parameterized queries, and logical plan diffing infrastructure. Co-authored-by: Isaac --- .../analysis/FunctionResolution.scala | 36 ++ .../ExpressionResolutionContext.scala | 28 +- .../resolver/ExpressionResolver.scala | 55 +- .../resolver/ExtractValueResolver.scala | 106 ++++ .../resolver/FunctionResolverUtils.scala | 157 ++++++ .../resolver/GroupingAnalyticsResolver.scala | 203 ++++++++ .../HigherOrderFunctionResolver.scala | 172 +++++++ ...dentifierFromUnresolvedNodeExtractor.scala | 44 ++ .../resolver/LambdaFunctionResolver.scala | 126 +++++ .../resolver/LogicalPlanDifference.scala | 223 ++++++++ .../NameParameterizedQueryResolver.scala | 173 +++++++ .../analysis/resolver/NameScope.scala | 14 + .../analysis/resolver/PivotResolver.scala | 170 ++++++ .../analysis/resolver/RecursiveCteState.scala | 46 ++ .../RepartitionByExpressionResolver.scala | 131 +++++ .../resolver/ResolverRunnerResult.scala | 39 ++ .../TableValuedFunctionResolver.scala | 80 +++ .../analysis/resolver/UnpivotResolver.scala | 278 ++++++++++ .../resolver/LogicalPlanDifferenceSuite.scala | 484 ++++++++++++++++++ 19 files changed, 2558 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExtractValueResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolverUtils.scala create mode 100644 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HigherOrderFunctionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierFromUnresolvedNodeExtractor.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LambdaFunctionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PivotResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RecursiveCteState.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RepartitionByExpressionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunnerResult.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TableValuedFunctionResolver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnpivotResolver.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala index d293520c08cf6..cb8f8f0d529d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionResolution.scala @@ -236,6 +236,42 @@ class FunctionResolution( None } + private def resolveInternalFunction( + name: String, arguments: Seq[Expression]): Expression = { + val qualified = FunctionIdentifier( + name, Some(CatalogManager.SESSION_NAMESPACE), Some(CatalogManager.SYSTEM_CATALOG_NAME)) + if (FunctionRegistry.internal.functionExists(qualified)) { + FunctionRegistry.internal.lookupFunction(qualified, arguments) + } else { + FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name), arguments) + } + } + + def resolveBuiltinOrTempFunction( + name: Seq[String], + arguments: Seq[Expression], + u: UnresolvedFunction): Option[Expression] = { + val expression = if (name.length == 1 && u.isInternal) { + Option(resolveInternalFunction(name.head, arguments)) + } else if (name.length == 1) { + v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments) + } else { + None + } + expression.map { func => + validateFunction(func, arguments.length, u) + } + } + + def resolveTableValuedFunction(u: UnresolvedTableValuedFunction): LogicalPlan = { + resolveTableFunction(u.name, u.functionArgs) + .getOrElse { + throw new NoSuchFunctionException( + db = u.name.dropRight(1).mkString("."), + func = u.name.last) + } + } + /** * Check if the arguments of a function are either resolved or a lambda function. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala index 6af1a585a231a..c4259ab001364 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} /** * The [[ExpressionResolutionContext]] is a state that is propagated between the nodes of the @@ -45,6 +45,8 @@ import org.apache.spark.sql.catalyst.expressions.Expression * [[ExpressionResolutionContext]] has LCA in its subtree. * @param hasWindowExpressions A flag that highlights that a specific node corresponding to * [[ExpressionResolutionContext]] has [[WindowExpression]]s in its subtree. + * @param hasGeneratorExpressions A flag that highlights that a specific node corresponding to + * [[ExpressionResolutionContext]] has [[Generator]] expressions in its subtree. * @param shouldPreserveAlias A flag indicating whether we preserve the [[Alias]] e.g. if it is on * top of a [[Project.projectList]]. If it is `false`, extra [[Alias]]es have to be stripped * away. @@ -84,6 +86,15 @@ import org.apache.spark.sql.catalyst.expressions.Expression * @param resolvingUnresolvedAlias A flag indicating whether we are resolving a tree under an * [[UnresolvedAlias]]. This is needed in order to prevent alias collapsing before the name of * [[UnresolvedAlias]] above is computed. + * @param resolvingPivotAggregates A flag indicating whether we are resolving a tree under the + * [[Pivot.aggregates]]. This is needed for validation of those expressions. 
+ * @param hasGroupingAnalyticsExpression A flag indicating whether a specific node corresponding to + * [[ExpressionResolutionContext]] has grouping expressions ([[Grouping]] or [[GroupingID]]) in + * its subtree. + * @param extractValueExtractionKey Extraction key for [[UnresolvedExtractValue]] if we are + * currently resolving one, None otherwise. + * @param lambdaVariableMap A map of lambda variable names to their corresponding + * [[NamedExpression]]s used to resolve [[UnresolvedLambdaVariable]]s inside lambda functions. */ class ExpressionResolutionContext( val parentContext: Option[ExpressionResolutionContext] = None, @@ -103,7 +114,11 @@ class ExpressionResolutionContext( var hasCorrelatedScalarSubqueryExpressions: Boolean = false, var resolvingTreeUnderAggregateExpression: Boolean = false, var resolvingCreateNamedStruct: Boolean = false, - var resolvingUnresolvedAlias: Boolean = false) { + var resolvingUnresolvedAlias: Boolean = false, + var resolvingPivotAggregates: Boolean = false, + var hasGroupingAnalyticsExpression: Boolean = false, + var extractValueExtractionKey: Option[Expression] = None, + var lambdaVariableMap: Option[IdentifierMap[NamedExpression]] = None) { /** * Propagate generic information that is valid across the whole expression tree from the @@ -116,6 +131,7 @@ class ExpressionResolutionContext( hasAttributeOutsideOfAggregateExpressions |= child.hasAttributeOutsideOfAggregateExpressions hasLateralColumnAlias |= child.hasLateralColumnAlias hasWindowExpressions |= child.hasWindowExpressions + hasGroupingAnalyticsExpression |= child.hasGroupingAnalyticsExpression hasCorrelatedScalarSubqueryExpressions |= child.hasCorrelatedScalarSubqueryExpressions } } @@ -135,7 +151,9 @@ object ExpressionResolutionContext { resolvingWindowFunction = parent.resolvingWindowFunction, windowFunctionNestednessLevel = parent.windowFunctionNestednessLevel, resolvingWindowSpec = parent.resolvingWindowSpec, - resolvingUnresolvedAlias = 
parent.resolvingUnresolvedAlias + resolvingUnresolvedAlias = parent.resolvingUnresolvedAlias, + resolvingPivotAggregates = parent.resolvingPivotAggregates, + lambdaVariableMap = parent.lambdaVariableMap ) } else { new ExpressionResolutionContext( @@ -147,7 +165,9 @@ object ExpressionResolutionContext { windowFunctionNestednessLevel = parent.windowFunctionNestednessLevel, resolvingWindowSpec = parent.resolvingWindowSpec, resolvingCreateNamedStruct = parent.resolvingCreateNamedStruct, - resolvingUnresolvedAlias = parent.resolvingUnresolvedAlias + resolvingUnresolvedAlias = parent.resolvingUnresolvedAlias, + resolvingPivotAggregates = parent.resolvingPivotAggregates, + lambdaVariableMap = parent.lambdaVariableMap ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala index 7cb7184f303c1..306c8513ee324 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ Aggregate, Filter, LogicalPlan, + Pivot, Project, Sort } @@ -133,7 +134,6 @@ class ExpressionResolver( private val scopes = resolver.getNameScopes private val subqueryRegistry = resolver.getSubqueryRegistry private val operatorResolutionContextStack = resolver.getOperatorResolutionContextStack - private val aliasResolver = new AliasResolver(this) private val timezoneAwareExpressionResolver = new TimezoneAwareExpressionResolver(this) private val binaryArithmeticResolver = new BinaryArithmeticResolver(this) @@ -221,6 +221,53 @@ class ExpressionResolver( resolvedExpression } + /** + * Resolve [[Pivot.aggregates]] expressions. This method resolves each expression using + * [[resolveExpressionTreeInOperatorImpl]]. 
We set the `resolvingPivotAggregates` flag to true + * to indicate that we are resolving pivot aggregates. + */ + def resolvePivotAggregates(pivot: Pivot): Seq[Expression] = { + pivot.aggregates.map { expression => + val (resolved, _) = resolveExpressionTreeInOperatorImpl( + unresolvedExpression = expression, + parentOperator = pivot, + resolvingPivotAggregates = true + ) + resolved + } + } + + /** + * Resolve [[Unpivot.values]] or [[Unpivot.ids]] expressions. This method first expands [[Star]] + * expressions, then resolves each expression using [[resolveExpressionTreeInOperatorImpl]]. + * We set the `shouldPreserveAlias` flag to true since both [[Unpivot.values]] and + * [[Unpivot.ids]] are sequences of [[NamedExpression]]s. + */ + def resolveUnpivotArguments( + arguments: Seq[Expression], + unpivot: LogicalPlan): Seq[NamedExpression] = { + val argumentsWithStarsExpanded = traversals.withNewTraversal(unpivot) { + expandStarExpressions(arguments) + } + + argumentsWithStarsExpanded.map { argument => + val (resolvedExpression, _) = resolveExpressionTreeInOperatorImpl( + parentOperator = unpivot, + unresolvedExpression = argument, + shouldPreserveAlias = true + ) + resolvedExpression.asInstanceOf[NamedExpression] + } + } + + /** + * Expand [[Star]] expressions in the given sequence of expressions. + */ + def expandStarExpressions(expressions: Seq[Expression]): Seq[Expression] = expressions.flatMap { + case star: Star => expandStar(star) + case other => Seq(other) + } + /** * This method is an expression analysis entry point. 
The method first checks if the expression * has already been resolved (necessary because of partially-unresolved subtrees, see @@ -603,7 +650,8 @@ class ExpressionResolver( unresolvedExpression: Expression, parentOperator: LogicalPlan, shouldPreserveAlias: Boolean = false, - resolvingGroupingExpressions: Boolean = false + resolvingGroupingExpressions: Boolean = false, + resolvingPivotAggregates: Boolean = false ): (Expression, ExpressionResolutionContext) = { traversals.withNewTraversal( parentOperator = parentOperator, @@ -615,7 +663,8 @@ class ExpressionResolver( new ExpressionResolutionContext( isRoot = true, shouldPreserveAlias = shouldPreserveAlias, - resolvingGroupingExpressions = resolvingGroupingExpressions + resolvingGroupingExpressions = resolvingGroupingExpressions, + resolvingPivotAggregates = resolvingPivotAggregates ) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExtractValueResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExtractValueResolver.scala new file mode 100644 index 0000000000000..dad863f20aed6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExtractValueResolver.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExtractValue} + +/** + * Resolver for [[UnresolvedExtractValue]]. Resolve the [[UnresolvedExtractValue]] by resolving its + * extraction key and child. + * + * [[UnresolvedExtractValue]] is constructed in parser, when indexing an + * array or a map with a key. For example, in the following query: + * + * {{{ + * SELECT col1.a, col2[0], col3['a'] + * FROM VALUES (named_struct('a', 1, 'b', 2), array(1,2,3), map('a', 1, 'b', 2)); + * }}} + * + * - `col1.a` is parsed as [[UnresolvedAttribute]] + * - `col2[0]` is parsed as [[UnresolvedExtractValue]] with `col2` as child and `0` as extraction + * key. + * - `col3['a']` is parsed as [[UnresolvedExtractValue]] with `col3` as child and `'a'` as + * extraction key. + */ +class ExtractValueResolver(expressionResolver: ExpressionResolver) + extends TreeNodeResolver[UnresolvedExtractValue, Expression] + with CoercesExpressionTypes { + private val traversals = expressionResolver.getExpressionTreeTraversals + private val expressionResolutionContextStack = + expressionResolver.getExpressionResolutionContextStack + + /** + * Resolves [[UnresolvedExtractValue]] by first resolving its extraction key and then its child. + * After resolving extraction key, we put the resolved key in current resolution context in order + * to allow [[NameScope]] to resolve attributes that are inside the [[ExtractValue]] expression. + * + * Handle the resolved [[ExtractValue]] by type coercing it and collecting it for window + * resolution, if needed. 
+ */ + def resolve(unresolvedExtractValue: UnresolvedExtractValue): Expression = { + val resolvedExtractionKey = expressionResolver.resolve(unresolvedExtractValue.extraction) + + expressionResolutionContextStack.peek().extractValueExtractionKey = Some(resolvedExtractionKey) + val resolvedChild = try { + expressionResolver.resolve(unresolvedExtractValue.child) + } finally { + expressionResolutionContextStack.peek().extractValueExtractionKey = None + } + + val resolvedExtractValue = ExtractValue.apply( + child = resolvedChild, + extraction = resolvedExtractionKey, + resolver = conf.resolver + ) + + resolvedExtractValue match { + case extractValue: ExtractValue => handleResolvedExtractValue(extractValue) + case other => other + } + } + + /** + * Coerces recursive types ([[ExtractValue]] expressions) in a bottom up manner and collects + * attribute references required for window resolution. For example: + * + * {{{ + * CREATE OR REPLACE TABLE t(col MAP); + * SELECT col.field FROM t; + * }}} + * + * In this example we need to cast inner field from `String` to `BIGINT`, thus analyzed plan + * should look like: + * + * {{{ + * Project [col#x[cast(field as bigint)] AS field#x] + * +- SubqueryAlias spark_catalog.default.t + * +- Relation spark_catalog.default.t[col#x] parquet + * }}} + * + * This is needed to stay compatible with the fixed-point implementation. 
+ */ + def handleResolvedExtractValue(extractValue: ExtractValue): Expression = { + extractValue.transformUp { + case attributeReference: AttributeReference => + coerceExpressionTypes(attributeReference, traversals.current) + case field => + coerceExpressionTypes(field, traversals.current) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolverUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolverUtils.scala new file mode 100644 index 0000000000000..503c94fc9cdf6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FunctionResolverUtils.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.Locale + +import org.apache.spark.sql.catalyst.analysis.{ + ResolvedStar, + Star, + UnresolvedFunction, + UnresolvedStar +} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +/** + * Trait with utility methods shared between [[FunctionResolver]] and + * [[HigherOrderFunctionResolver]]. + */ +trait FunctionResolverUtils { + protected def expressionResolver: ExpressionResolver + protected def conf: SQLConf + + private val scopes = expressionResolver.getNameScopes + + /** + * Expand all star expressions in arguments. Separately handles 2 cases with count function: + * - `count(*)` is replaced with `count(1)`. See [[normalizeCountExpression]] + * + * - `count(table.*)` throws an exception if the flag + * [[SQLConf.ALLOW_STAR_WITH_SINGLE_TABLE_IDENTIFIER_IN_COUNT]] is false. + * (see [[assertSingleTableStarNotInCountFunction]]) + * It is done to avoid confusion since `count(*)` and `count(table.*)` would produce + * different results: + * - `count(*)` returns the number of rows + * - `count(table.*)` returns the number of rows where all columns are not null. It's the same + * behavior as if explicitly listing all columns of the table in count. + * + * Returns [[UnresolvedFunction]] without any star expressions in arguments. 
+ */ + protected def handleStarInArguments( + unresolvedFunction: UnresolvedFunction): UnresolvedFunction = { + val functionContainsStarInArguments = unresolvedFunction.arguments.exists { + case _: Star => true + case _ => false + } + + if (!functionContainsStarInArguments) { + unresolvedFunction + } else if (isNonDistinctCount(unresolvedFunction) && + hasSingleSimpleStarArgument(unresolvedFunction)) { + normalizeCountExpression(unresolvedFunction) + } else { + assertSingleTableStarNotInCountFunction(unresolvedFunction) + unresolvedFunction.copy( + arguments = expressionResolver.expandStarExpressions(unresolvedFunction.arguments) + ) + } + } + + /** + * Check if the given unresolved function has one `*` as argument. Usually used to detect cases + * where `count(*)` should be replaced with `count(1)`. + * + * Returns True for [[ResolvedStar]] and [[UnresolvedStar]] without specified target. + * + * Note that it's False even for other implementation of [[Star]] trait, for example + * [[UnresolvedStarExceptOrReplace]] (`* except ...`) or [[UnresolvedStar]] with specified + * target (`table.*`). + */ + private def hasSingleSimpleStarArgument(unresolvedFunction: UnresolvedFunction): Boolean = + unresolvedFunction.arguments match { + case Seq(UnresolvedStar(None)) => true + case Seq(_: ResolvedStar) => true + case _ => false + } + + /** + * Method used to determine whether the given function is non-distinct `count` function, + * with optional normalization. 
+ */ + private def isNonDistinctCount( + unresolvedFunction: UnresolvedFunction, + normalizeFunctionName: Boolean = true + ): Boolean = { + !unresolvedFunction.isDistinct && isCount(unresolvedFunction, normalizeFunctionName) + } + + private def isCount( + unresolvedFunction: UnresolvedFunction, + normalizeFunctionName: Boolean = true + ): Boolean = { + val isCountName = if (normalizeFunctionName) { + unresolvedFunction.nameParts.head.toLowerCase(Locale.ROOT) == "count" + } else { + unresolvedFunction.nameParts.head == "count" + } + + unresolvedFunction.nameParts.length == 1 && isCountName + } + + /** + * Method used to replace the `count(*)` function with `count(1)` function. Resolution of the + * `count(*)` is done in the following way: + * - SQL: It is done during the construction of the AST (in [[AstBuilder]]). + * - Dataframes: It is done during the analysis phase and that's why we need to do it here. + */ + private def normalizeCountExpression( + unresolvedFunction: UnresolvedFunction): UnresolvedFunction = { + unresolvedFunction.copy( + nameParts = Seq("count"), + arguments = Seq(Literal(1)), + filter = unresolvedFunction.filter + ) + } + + /** + * Throws an exception according to [[SQLConf.ALLOW_STAR_WITH_SINGLE_TABLE_IDENTIFIER_IN_COUNT]]. + * + * Note that check for function name is case-sensitive. Even when flag is false we allow + * `COUNT(tableName.*)` but block `count(tableName.*)`. This is the same behavior as in + * fixed-point analyzer, and clearly it is a bug. If this is ever fixed it should be done in both + * analyzers simultaneously. 
+ * + * See [[handleStarInArguments]] + */ + private def assertSingleTableStarNotInCountFunction( + unresolvedFunction: UnresolvedFunction): Unit = { + if (!conf.allowStarWithSingleTableIdentifierInCount && isCount( + unresolvedFunction = unresolvedFunction, + normalizeFunctionName = false + ) && unresolvedFunction.arguments.length == 1) { + unresolvedFunction.arguments.head match { + case star: UnresolvedStar if scopes.current.isStarQualifiedByTable(star) => + throw QueryCompilationErrors + .singleTableStarInCountNotAllowedError(star.target.get.mkString(".")) + case _ => + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala new file mode 100644 index 0000000000000..6ba98677ab2c8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.GroupingAnalyticsTransformer +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, + BaseGroupingSets, + Expression, + GroupingAnalyticsExtractor, + SortOrder, + VirtualColumn +} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Sort} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * Resolves [[Aggregate]] node with grouping analytics (i.e., ROLLUP, CUBE, GROUPING SETS). + */ +class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: ExpressionResolver) + extends TreeNodeResolver[Aggregate, Aggregate] { + private val scopes: NameScopeStack = resolver.getNameScopes + private val autoGeneratedAliasProvider = expressionResolver.getAutoGeneratedAliasProvider + private val operatorResolutionContextStack = resolver.getOperatorResolutionContextStack + + /** + * Resolution of the [[Aggregate]] is done in following steps: + * 1. Extract grouping analytics from grouping expressions using [[GroupingAnalyticsExtractor]]. + * 2. Transform the [[Aggregate]] node with grouping analytics using + * [[GroupingAnalyticsTransformer]] (see its scala doc for more details). + * 3. If the child of the transformed [[Aggregate]] is an [[Expand]] node, update its output + * attributes to have correct expression IDs. 
+ * + * For example for this query: + * + * {{{ + * SELECT col1 FROM values(1) GROUP BY grouping sets ((col1), ()); + * }}} + * + * The parsed plan would be: + * + * {{{ + * 'Aggregate [groupingsets(Vector(0), Vector(), 'col1)], ['col1] + * +- LocalRelation [col1#0] + * }}} + * + * Whereas the resolved plan would be: + * + * {{{ + * Aggregate [col1#3, spark_grouping_id#2L], [col1#3] + * +- Expand [[col1#0, col1#1, 0], [col1#0, null, 1]], [col1#0, col1#3, spark_grouping_id#2L] + * +- Project [col1#0, col1#0 AS col1#1] + * +- LocalRelation [col1#0] + * }}} + * + * As can be seen, the [[Expand]] node propagates two `col1` attributes which have different + * properties. The first one (`col1#0`) comes from the child of the original [[Aggregate]] and it + * is passed unchanged. The second one (`col1#1`) is generated by the + * [[GroupingAnalyticsTransformer]] to represent `col1` in the grouping set and it has updated + * nullability. Its nullability is set to `true` since in the grouping set `col1` can be absent + * (i.e., null) as in the second grouping set `()`. The second one is used later to resolve `col1` + * in the [[Aggregate]] node. 
+ */ + override def resolve(aggregate: Aggregate): Aggregate = { + val groupingAnalytics = tryExtractGroupingAnalytics(aggregate) + + groupingAnalytics match { + case Some((selectedGroupByExpressions, groupByExpressions)) => + operatorResolutionContextStack.current.hasGroupingAnalytics = true + + val transformedAggregate = GroupingAnalyticsTransformer( + newAlias = (child, name, qualifier) => + autoGeneratedAliasProvider.newAlias(child, name, qualifier = qualifier), + childOutput = scopes.current.output, + groupByExpressions = groupByExpressions, + selectedGroupByExpressions = selectedGroupByExpressions, + child = aggregate.child, + aggregationExpressions = aggregate.aggregateExpressions + ) + + val newAggregateChild = transformedAggregate.child match { + case expand: Expand => + handleExpandBelowAggregate(expand) + case other => other + } + + transformedAggregate.copy(child = newAggregateChild) + case None => + aggregate + } + } + + /** + * Handles [[SortOrder]] expressions that contain grouping analytics expressions. This is done by: + * 1. Validating that there is a base [[Aggregate]] in the current scope: + * - If there is, collect grouping expressions from it using + * [[GroupingAnalyticsTransformer.collectGroupingExpressions]]. + * - If there isn't, throw `groupingMustWithGroupingSetsOrCubeOrRollupError` exception. + * 2. Create a grouping ID attribute and resolve it using [[ExpressionResolver]]. + * 3. Replace [[SortOrder]] expressions using + * [[GroupingAnalyticsTransformer.replaceGroupingFunction]] (see its scala doc for more + * details). 
+ */ + def handleSortOrderExpressionsWithGroupingAnalytics( + orderExpressions: Seq[SortOrder], + sort: Sort): Seq[SortOrder] = { + val groupingExpressions = if (scopes.current.baseAggregate.isDefined) { + GroupingAnalyticsTransformer.collectGroupingExpressions( + scopes.current.baseAggregate.get + ) + } else { + throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() + } + + val groupingId = VirtualColumn.groupingIdAttribute + val resolvedGroupingId = expressionResolver + .resolveExpressionTreeInOperator(groupingId, sort) + .asInstanceOf[AttributeReference] + + val orderExpressionsWithGroupingAnalytics = orderExpressions.map { orderExpression => + GroupingAnalyticsTransformer + .replaceGroupingFunction( + expression = orderExpression, + groupByExpressions = groupingExpressions, + gid = resolvedGroupingId, + newAlias = (child, name, qualifier) => + autoGeneratedAliasProvider.newAlias(child, name, qualifier = qualifier) + ) + .asInstanceOf[SortOrder] + } + + orderExpressionsWithGroupingAnalytics + } + + /** + * Extract expressions that contain grouping analytics operations from an + * [[Aggregate.groupingExpressions]] using [[GroupingAnalyticsExtractor]]. See its scala doc for + * more details. + */ + private def tryExtractGroupingAnalytics( + aggregate: Aggregate): Option[(Seq[Seq[Expression]], Seq[Expression])] = { + aggregate.groupingExpressions match { + case exprs if exprs.exists(_.isInstanceOf[BaseGroupingSets]) => + GroupingAnalyticsExtractor(exprs) + case _ => + None + } + } + + /** + * Handles [[Expand]] node below an [[Aggregate]]. This includes: + * - Updating output attributes to have correct expression IDs. + * - Overwriting the current scope output and extending hidden output with the new output + * since operators above the [[Aggregate]] might resolve attributes using hidden output + * produced by the [[Expand]]. + * - Copying tags from the original [[Expand]]. 
+ */ + private def handleExpandBelowAggregate(expand: Expand): Expand = { + val mappedOutput = expand.output.map { attribute => + expressionResolver.getExpressionIdAssigner + .mapExpression(attribute, allowUpdatesForAttributeReferences = true) + } + + val newExpand = expand.copy(output = mappedOutput) + + scopes.overwriteOutputAndExtendHiddenOutput(output = newExpand.output) + newExpand.copyTagsFrom(expand) + + newExpand + } +} + +/** + * Helper object for [[GroupingAnalyticsResolver]]. + */ +object GroupingAnalyticsResolver { + + /** + * In case there are grouping analytics outside of an [[Aggregate]] node (e.g. [[Filter]]), throw + * an [[ExplicitlyUnsupportedResolverFeature]] exception. + */ + def restrictGroupingAnalyticsBelowSortAndFilter( + operatorResolutionContext: OperatorResolutionContext): Unit = { + if (operatorResolutionContext.hasGroupingAnalytics) { + throw new ExplicitlyUnsupportedResolverFeature("grouping analytics outside of Aggregate") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HigherOrderFunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HigherOrderFunctionResolver.scala new file mode 100644 index 0000000000000..6b90a5c05baf1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HigherOrderFunctionResolver.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.{ + AnalysisErrorAt, + FunctionResolution, + LambdaBinder, + SubqueryExpressionInLambdaOrHigherOrderFunctionValidator, + UnresolvedFunction +} +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.MapZipWithTypeCoercion +import org.apache.spark.sql.catalyst.expressions.{Expression, HigherOrderFunction} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * A resolver specifically for higher-order functions (functions that accept lambda expressions). + */ +class HigherOrderFunctionResolver( + protected val expressionResolver: ExpressionResolver, + functionResolution: FunctionResolution) + extends TreeNodeResolver[UnresolvedFunction, Expression] + with ProducesUnresolvedSubtree + with CoercesExpressionTypes + with FunctionResolverUtils { + + private val traversals = expressionResolver.getExpressionTreeTraversals + + /** + * Resolves an [[UnresolvedFunction]] representing a higher-order function into a fully resolved + * [[HigherOrderFunction]] expression. This is done in the following steps: + * 1. Handle any star (`*`) arguments in the function call. See more in the + * [[FunctionResolverUtils.handleStarInArguments]] scala doc. + * 2. Resolve the function to a built-in or temporary function. + * 3. Validate that the resolved function is indeed a higher-order function. Otherwise, throw an + * exception. + * 4. Recursively resolve all arguments of the higher-order function. + * 5. 
Apply the [[MapZipWithTypeCoercion]] type coercion transformation to the higher-order + * function. It must be done before binding the higher-order function so that the + * [[MapZipWith]] arguments have the correct types. + * 6. Validate that there are no subquery expressions within the higher-order function. If there + * are, throw an exception. + * 7. Bind the higher-order function using the [[LambdaBinder]]. See more in the [[LambdaBinder]] + * scala doc. + * 8. Recursively resolve all lambda functions within the higher-order function. + * 9. Coerce the higher-order function. + * 10. Apply a time zone if needed. + * + * In case of the following query: + * + * {{{ + * SELECT filter(array('a'), x -> x = 'a'); + * }}} + * + * The parsed plan would be: + * + * {{{ + * 'Project [unresolvedalias('filter('array(a), lambdafunction((lambda 'x = a), lambda 'x)))] + * +- OneRowRelation + * }}} + * + * Whereas the analyzed plan is: + * + * {{{ + * Project [filter(array(a), lambdafunction((lambda x#0 = a), lambda x#0)) AS ...] + * +- OneRowRelation + * }}} + * + * For a query with star (`*`) expansion: + * + * {{{ + * SELECT transform(array(*), x -> x + 1) FROM VALUES (1); + * }}} + * + * The parsed plan would be: + * + * {{{ + * 'Project [unresolvedalias('transform('array(*), lambdafunction((lambda 'x + 1), lambda 'x)))] + * +- SubqueryAlias t + * +- LocalRelation [col1#0] + * }}} + * + * Whereas the analyzed plan (after star expansion) is: + * + * {{{ + * Project [transform(array(col1#0), lambdafunction((lambda x#3 + 1), lambda x#3)) AS ...] 
+ * +- SubqueryAlias t + * +- LocalRelation [col1#0] + * }}} + */ + override def resolve(unresolvedFunction: UnresolvedFunction): Expression = { + val unresolvedFunctionWithExpandedStarArgs = handleStarInArguments(unresolvedFunction) + + val partiallyResolvedFunction = functionResolution.resolveBuiltinOrTempFunction( + name = unresolvedFunctionWithExpandedStarArgs.nameParts, + arguments = unresolvedFunctionWithExpandedStarArgs.arguments, + u = unresolvedFunctionWithExpandedStarArgs + ) + + val partiallyResolvedHigherOrderFunction = + validateHigherOrderFunction(partiallyResolvedFunction, unresolvedFunctionWithExpandedStarArgs) + + val resolvedArguments = partiallyResolvedHigherOrderFunction.arguments.map { arg => + expressionResolver.resolve(arg) + } + + val higherOrderFunctionWithResolvedArguments = + partiallyResolvedHigherOrderFunction + .withNewChildren(resolvedArguments ++ partiallyResolvedHigherOrderFunction.functions) + + val functionWithCoercedArguments = + MapZipWithTypeCoercion(higherOrderFunctionWithResolvedArguments) + .asInstanceOf[HigherOrderFunction] + + SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(functionWithCoercedArguments) + + val boundHigherOrderFunction = functionWithCoercedArguments.bind( + (expression, argumentsInfo) => LambdaBinder(expression, argumentsInfo) + ) + + val resolvedFunctions = boundHigherOrderFunction.functions.map { func => + expressionResolver.resolve(func) + } + + val higherOrderFunctionWithResolvedChildren = + boundHigherOrderFunction.withNewChildren(resolvedArguments ++ resolvedFunctions) + + val resolvedHigherOrderFunction = coerceExpressionTypes( + expression = higherOrderFunctionWithResolvedChildren, + expressionTreeTraversal = traversals.current + ) + + TimezoneAwareExpressionResolver.resolveTimezone( + expression = resolvedHigherOrderFunction, + timeZoneId = traversals.current.sessionLocalTimeZone + ) + } + + private def validateHigherOrderFunction( + partiallyResolvedFunction: Option[Expression], + 
unresolvedFunction: UnresolvedFunction): HigherOrderFunction = { + partiallyResolvedFunction match { + case Some(higherOrderFunction: HigherOrderFunction) => higherOrderFunction + case Some(other) => + other.failAnalysis( + errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION", + messageParameters = Map("class" -> other.getClass.getCanonicalName) + ) + case None => + throw QueryCompilationErrors.unresolvedRoutineError( + unresolvedFunction.nameParts, + Seq("system.builtin", "system.session"), + unresolvedFunction.origin + ) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierFromUnresolvedNodeExtractor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierFromUnresolvedNodeExtractor.scala new file mode 100644 index 0000000000000..7f6409af19d02 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/IdentifierFromUnresolvedNodeExtractor.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.Locale + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation} +import org.apache.spark.sql.connector.catalog.{CatalogManager, LookupCatalog} + +class IdentifierFromUnresolvedNodeExtractor( + override val catalogManager: CatalogManager, + relationResolution: RelationResolution) + extends LookupCatalog { + def apply(unresolvedRelation: UnresolvedRelation): Option[TableIdentifier] = { + relationResolution.expandIdentifier(unresolvedRelation.multipartIdentifier) match { + case CatalogAndIdentifier(catalog, identifier) => + Some( + TableIdentifier( + catalog = Some(catalog.name().toLowerCase(Locale.ROOT)), + database = Some(identifier.namespace().head.toLowerCase(Locale.ROOT)), + table = identifier.name().toLowerCase(Locale.ROOT) + ) + ) + case _ => + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LambdaFunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LambdaFunctionResolver.scala new file mode 100644 index 0000000000000..788ff6e702d6c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LambdaFunctionResolver.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction, NamedExpression} + +/** + * Resolver for [[LambdaFunction]] expressions. + */ +class LambdaFunctionResolver(expressionResolver: ExpressionResolver) + extends TreeNodeResolver[LambdaFunction, Expression] + with ResolvesExpressionChildren { + private val expressionResolutionContextStack = + expressionResolver.getExpressionResolutionContextStack + + /** + * Resolves the [[LambdaFunction]] based on the type of the lambda function. There are two + * cases: + * 1. If the lambda function is not hidden, its arguments are added to the current + * [[ExpressionResolutionContext.lambdaVariableMap]] so that they can be used for resolution + * of the [[UnresolvedNamedLambdaVariables]] later in the [[ExpressionResolver]]. Hidden + * functions are the ones that are created internally when the user specifies a non-lambda + * argument (instead of providing an actual lambda function, which would be semantically + * more correct). + * In the following example we would replace each array element with the constant `0`, + * resulting in `array(0)`. Spark internally creates a hidden lambda since the user didn't + * provide one: + * + * {{{ + * SELECT transform(array(1), 0); + * }}} + * + * Here we would have `lambdafunction(0, lambda col0#0, true)` where the original `0` + * would be the `function` and `NamedLambdaVariable(col0#0)` would be the `argument`. + * + * 2. 
If the lambda function is hidden, it is resolved by resolving its children without adding + * its arguments to the current [[ExpressionResolutionContext.lambdaVariableMap]]. + * + * In case there is a query: + * + * {{{ + * SELECT filter(array('a'), x -> x = 'a'); + * }}} + * + * Lambda function would be `x -> x = 'a'`, which is not hidden. So, `x` (left side one) would be + * added to the current [[ExpressionResolutionContext.lambdaVariableMap]] and used later when + * resolving the `x = 'a'` expression (`x` is an [[UnresolvedNamedLambdaVariable]] at that + * point). + * + * In case of nested lambdas: + * + * {{{ + * SELECT + * transform( + * nested_arrays, + * inner_array -> aggregate( + * inner_array, + * 1, + * (product, element) -> product * element * size(inner_array) + * ) + * ) + * FROM VALUES ( + * array(array(1, 2), array(3, 4)) + * ) AS t(nested_arrays); + * }}} + * + * We need to pass the `inner_array` lambda variable to the inner lambda function resolver + * because it is used inside the inner lambda function body. 
So, while resolving the `aggregate` + * lambda function the `lambdaVariableMap` would look like: + * - 'product' -> NamedLambdaVariable(product) + * - 'element' -> NamedLambdaVariable(element) + * - 'inner_array' -> NamedLambdaVariable(inner_array) + */ + override def resolve(lambdaFunction: LambdaFunction): Expression = { + val expressionResolutionContext = expressionResolutionContextStack.peek() + val previousLambdaVariableMap = expressionResolutionContext.lambdaVariableMap.map { + existingMap => + existingMap.copyTo(new IdentifierMap[NamedExpression]) + } + + try { + if (!lambdaFunction.hidden) { + resolveUnhiddenLambdaFunction( + lambdaFunction = lambdaFunction, + expressionResolutionContext = expressionResolutionContext + ) + } else { + withResolvedChildren( + unresolvedExpression = lambdaFunction, + resolveChild = expressionResolver.resolve _ + ) + } + } finally { + expressionResolutionContext.lambdaVariableMap = previousLambdaVariableMap + } + } + + private def resolveUnhiddenLambdaFunction( + lambdaFunction: LambdaFunction, + expressionResolutionContext: ExpressionResolutionContext): Expression = { + val lambdaMap = expressionResolutionContext.lambdaVariableMap + .map(_.copyTo(new IdentifierMap[NamedExpression])) + .getOrElse(new IdentifierMap[NamedExpression]) + lambdaFunction.arguments.foreach(argument => lambdaMap += (argument.name, argument)) + expressionResolutionContext.lambdaVariableMap = Some(lambdaMap) + + withResolvedChildren( + unresolvedExpression = lambdaFunction, + resolveChild = expressionResolver.resolve _ + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala new file mode 100644 index 0000000000000..c08977f9efb53 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala @@ -0,0 +1,223 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.lang.StringBuilder +import java.util.ArrayDeque + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Utility object for finding differences between logical plans and generating truncated + * representations showing context around the first mismatch. + */ +object LogicalPlanDifference { + + /** + * Generate truncated versions of two logical plans showing context around the first mismatch. + * + * This method finds the first line where the string representations of the two plans differ, + * then returns truncated versions of both plans containing: + * - Up to contextSize lines before the first mismatch + * - The mismatched line itself + * - Up to contextSize lines after the mismatch + * + * If the plans are identical, returns empty strings for both plans. + * + * This method performs a simple line-by-line comparison of the string representations of the + * plans. It does not perform any semantic analysis or structural comparison. The semantic + * comparison should be done beforehand. 
+ * + * @param lhsPlan the left-hand side logical plan to compare + * @param rhsPlan the right-hand side logical plan to compare + * @param contextSize the number of lines to show before and after the mismatch + * @return a tuple of (truncatedLhsPlan, truncatedRhsPlan) + */ + def apply(lhsPlan: LogicalPlan, rhsPlan: LogicalPlan, contextSize: Int): (String, String) = { + val lhsPlanString = lhsPlan.toString + val rhsPlanString = rhsPlan.toString + + val lhsIterator = new LineIterator(lhsPlanString) + val rhsIterator = new LineIterator(rhsPlanString) + + val lhsBuffer = new CircularBuffer(contextSize + 1) + val rhsBuffer = new CircularBuffer(contextSize + 1) + + var mismatchFound = false + + while (!mismatchFound && (lhsIterator.hasNext || rhsIterator.hasNext)) { + val lhsLine = if (lhsIterator.hasNext) Some(lhsIterator.next()) else None + val rhsLine = if (rhsIterator.hasNext) Some(rhsIterator.next()) else None + + lhsLine.foreach(lhsBuffer.add) + rhsLine.foreach(rhsBuffer.add) + + mismatchFound = checkLineDifference( + lhsLine = lhsLine, + rhsLine = rhsLine, + lhsSource = lhsPlanString, + rhsSource = rhsPlanString + ) + } + + if (!mismatchFound) { + ("", "") + } else { + val lhsResult = new StringBuilder() + val rhsResult = new StringBuilder() + + lhsBuffer.foreach { position => + lhsResult.append(lhsPlanString, position.start, position.end).append('\n') + } + rhsBuffer.foreach { position => + rhsResult.append(rhsPlanString, position.start, position.end).append('\n') + } + + var linesAfter = 0 + while (linesAfter < contextSize && (lhsIterator.hasNext || rhsIterator.hasNext)) { + if (lhsIterator.hasNext) { + val position = lhsIterator.next() + lhsResult.append(lhsPlanString, position.start, position.end).append('\n') + } + if (rhsIterator.hasNext) { + val position = rhsIterator.next() + rhsResult.append(rhsPlanString, position.start, position.end).append('\n') + } + + linesAfter += 1 + } + + (lhsResult.toString(), rhsResult.toString()) + } + } + + /** + * Checks if 
two line positions represent different content by comparing the corresponding + * regions in the source strings without materializing the line strings. + * + * @param lhsLine the left-hand side line position (optional) + * @param rhsLine the right-hand side line position (optional) + * @param lhsSource the source string for the left-hand side line + * @param rhsSource the source string for the right-hand side line + * @return true if the lines differ, false if they are the same + */ + private def checkLineDifference( + lhsLine: Option[LinePosition], + rhsLine: Option[LinePosition], + lhsSource: String, + rhsSource: String): Boolean = { + (lhsLine, rhsLine) match { + case (Some(lhsPosition), Some(rhsPosition)) => + val lhsLength = lhsPosition.end - lhsPosition.start + val rhsLength = rhsPosition.end - rhsPosition.start + lhsLength != rhsLength || !lhsSource.regionMatches( + lhsPosition.start, + rhsSource, + rhsPosition.start, + lhsLength + ) + case (None, None) => false + case _ => true + } + } + + /** + * Represents a line position in a string by its start and end indices, avoiding string + * allocation until the line content is actually needed. + * + * @param start the starting index of the line (inclusive) + * @param end the ending index of the line (exclusive) + */ + private case class LinePosition(start: Int, end: Int) + + /** + * Iterator that iterates over lines in a string, returning line positions (start and end + * indices) instead of materialized strings to avoid allocations. + * + * @param str the string to iterate over + */ + private class LineIterator(str: String) extends Iterator[LinePosition] { + private var position = 0 + private val stringLength = str.length + + /** + * Returns true if there are more lines to read from the string. 
+ * + * @return true if the iterator has more elements, false otherwise + */ + override def hasNext: Boolean = position < stringLength + + /** + * Returns the position (start and end indices) of the next line from the string. A line is + * defined as a sequence of characters ending with a newline character or the end of the + * string. The newline character itself is not included in the line position. + * + * @return the position of the next line as (start, end) indices + * @throws NoSuchElementException if there are no more lines to read + */ + override def next(): LinePosition = { + if (!hasNext) throw new NoSuchElementException("next on empty iterator") + + val start = position + while (position < stringLength && str.charAt(position) != '\n') { + position += 1 + } + + val linePosition = LinePosition(start, position) + if (position < stringLength) position += 1 + + linePosition + } + } + + /** + * Circular buffer that maintains the last N line positions added. + * + * @param capacity the maximum number of line positions to retain in the buffer + */ + private class CircularBuffer(capacity: Int) { + private val buffer = new ArrayDeque[LinePosition](capacity) + + /** + * Adds a line position to the buffer. If the buffer is at capacity, the oldest element is + * removed before adding the new element. If capacity is 0, the line position is not added. + * + * @param linePosition the line position to add to the buffer + */ + def add(linePosition: LinePosition): Unit = { + if (capacity > 0) { + if (buffer.size() >= capacity) { + buffer.removeFirst() + } + buffer.addLast(linePosition) + } + } + + /** + * Applies a function to each line position in the buffer in the order they were added. 
+ * + * @param f the function to apply to each line position + */ + def foreach(f: LinePosition => Unit): Unit = { + val iterator = buffer.iterator() + while (iterator.hasNext) { + f(iterator.next()) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala new file mode 100644 index 0000000000000..861d0aa744a18 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.LinkedHashMap + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, NameParameterizedQuery} +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + Literal, + MapFromArrays, + MapFromEntries, + VariableReference +} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand} + +/** + * Resolver class that resolves [[NameParameterizedQuery]] operators. + */ +class NameParameterizedQueryResolver( + operatorResolver: Resolver, + expressionResolver: ExpressionResolver) { + private val operatorResolutionContextStack = operatorResolver.getOperatorResolutionContextStack + private val scopes = operatorResolver.getNameScopes + + /** + * Resolves a [[NameParameterizedQuery]] operator. It handles two cases: + * 1. If the child of the [[NameParameterizedQuery]] is a [[SupervisingCommand]], it moves the + * [[NameParameterizedQuery]] below the [[SupervisingCommand]]. It is resolved later when + * `SupervisingCommand.run` is called. + * 2. Otherwise, it resolves the [[NameParameterizedQuery]] by resolving its parameter + * expressions and replacing the parameter names with their corresponding values in the + * [[ExpressionResolver]]. + */ + def resolve(nameParameterizedQuery: NameParameterizedQuery): LogicalPlan = { + nameParameterizedQuery match { + case nameParameterizedQuery @ NameParameterizedQuery(command: SupervisingCommand, _, _) => + pushParameterizedQueryBelowCommand(nameParameterizedQuery, command) + case nameParameterizedQuery => + resolveNameParameterizedQuery(nameParameterizedQuery) + } + } + + /** + * Moves the [[NameParameterizedQuery]] below the given [[SupervisingCommand]]. 
+ * This is done to ensure that the parameters are resolved in the context of the actual plan, + * whereas the [[SupervisingCommand]] is expected to be a top-level node. + * [[NameParameterizedQuery]] should be pushed down through any nested [[SupervisingCommand]]s. + * Examples: + * + * 1. Parameters below EXPLAIN command: + * {{{ + * EXPLAIN SELECT :first; + * + * -- Analyzed plan + * -- ExplainCommand 'NameParameterizedQuery [first], [1], SimpleMode + * }}} + * + * 2. Parameters below DESCRIBE command: + * {{{ + * DESCRIBE QUERY SELECT :first; + * + * -- Analyzed plan + * -- DescribeQueryCommand 'NameParameterizedQuery [first], [1] + * }}} + * + * 3. Parameters below nested commands: + * {{{ + * EXPLAIN EXPLAIN SELECT :first; + * + * -- Analyzed plan + * -- ExplainCommand ExplainCommand 'NameParameterizedQuery [first], [1], SimpleMode, SimpleMode + * }}} + * + */ + private def pushParameterizedQueryBelowCommand( + nameParameterizedQuery: NameParameterizedQuery, + command: SupervisingCommand): LogicalPlan = { + command.withTransformedSupervisedPlan { + case nestedCommand: SupervisingCommand => + pushParameterizedQueryBelowCommand(nameParameterizedQuery, nestedCommand) + case supervisedPlan => + nameParameterizedQuery.copy(child = supervisedPlan) + } + } + + /** + * Resolves a [[NameParameterizedQuery]] operator. It's done in the following steps: + * + * 1. Resolve the parameter expressions using the + * [[ExpressionResolver.resolveExpressionTreeInOperator]]. Resolution is done in a new scope + * to avoid polluting the current scope with references from the parameter expressions. + * 2. Validate that the number of parameter names matches the number of parameter values. + * 3. Create a mapping of parameter names to their corresponding resolved values and set it in + * [[OperatorResolutionContext]] in order to use it for [[NamedParameter]] resolution later. + * 4. Check that all parameter values are of allowed types. + * 5. 
Resolve the child operator of the [[NameParameterizedQuery]]. + */ + private def resolveNameParameterizedQuery( + nameParameterizedQuery: NameParameterizedQuery): LogicalPlan = { + val parameterNames = nameParameterizedQuery.argNames + val parameterValues = nameParameterizedQuery.argValues + + scopes.pushScope() + val resolvedParameterValues = try { + parameterValues.map { parameterValue => + expressionResolver.resolveExpressionTreeInOperator( + unresolvedExpression = parameterValue, + parentOperator = nameParameterizedQuery + ) + } + } finally { + scopes.popScope() + scopes.current.clearAvailableAliases() + } + + validateParameterNamesAndValuesSizes(parameterNames, parameterValues) + + operatorResolutionContextStack.current.parameterNamesToValues = Some( + new LinkedHashMap[String, Expression] + ) + + parameterNames.zip(resolvedParameterValues).foreach { + case (name, value) if isNotAllowed(value) => + value.failAnalysis( + errorClass = "INVALID_SQL_ARG", + messageParameters = Map("name" -> name) + ) + case (name, value) => + operatorResolutionContextStack.current.parameterNamesToValues.get.put(name, value) + } + + operatorResolver.resolve(nameParameterizedQuery.child) + } + + private def validateParameterNamesAndValuesSizes( + parameterNames: Seq[String], + parameterValues: Seq[Expression]): Unit = { + if (parameterNames.length != parameterValues.length) { + throw SparkException.internalError( + s"The number of argument names ${parameterNames.length} " + + s"must be equal to the number of argument values ${parameterValues.length}." 
+ ) + } + } + + private def isNotAllowed(expression: Expression): Boolean = expression.exists { + case _: Literal | _: CreateArray | _: CreateNamedStruct | _: CreateMap | _: MapFromArrays | + _: MapFromEntries | _: VariableReference | _: Alias => + false + case _ => true + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala index a3e02ffc78890..1c9a296af9ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala @@ -219,6 +219,12 @@ class NameScope( private lazy val hiddenAttributesById: HashMap[ExprId, Attribute] = createAttributeIds(hiddenOutput) + /** + * Whether the scope can resolve names by hidden output. This is set to `true` by + * [[markScopeForHiddenOutputResolution]]. + */ + private var canResolveNameByHiddenOutput: Boolean = false + lazy val lcaRegistry: LateralColumnAliasRegistry = if (conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { new LateralColumnAliasRegistryImpl @@ -234,6 +240,14 @@ class NameScope( private lazy val topAggregateExpressionsByAliasName: IdentifierMap[ArrayList[Alias]] = new IdentifierMap[ArrayList[Alias]] + /** + * Sets `canResolveNameByHiddenOutput` to `true`, enabling hidden output resolution for this + * scope. 
+ */ + def markScopeForHiddenOutputResolution(): Unit = { + canResolveNameByHiddenOutput = true + } + /** * Returns new [[NameScope]] which preserves all the immutable [[NameScope]] properties but * overwrites `output`, `hiddenOutput`, `availableAliases`, `aggregateListAliases` and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PivotResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PivotResolver.scala new file mode 100644 index 0000000000000..6d7525a1b72c2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PivotResolver.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.PivotTransformer +import org.apache.spark.sql.catalyst.expressions.{ + AliasHelper, + Attribute, + Expression, + NamedExpression, + RowOrdering +} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Pivot, Project} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * Resolver class that resolves [[Pivot]] operators. 
+ */ +class PivotResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver) + extends TreeNodeResolver[Pivot, LogicalPlan] + with AliasHelper + with CoercesExpressionTypes { + private val scopes: NameScopeStack = operatorResolver.getNameScopes + private val autoGeneratedAliasProvider = expressionResolver.getAutoGeneratedAliasProvider + + /** + * Resolves a [[Pivot]] operator by: + * 1. Resolving its child operator. + * 2. Resolving its group by expressions (if any). + * 3. Resolving its pivot column. + * 4. Resolving its pivot values. + * 5. Resolving its aggregate expressions. + * 6. Validating that the pivot column is orderable. + * 7. Constructing the resolved operator using [[PivotTransformer]] (see its doc for details). + * 8. Handling the resolved operator to resolve any remaining expressions. + * 9. Updating the current scope with the new output attributes. + * + * For example consider the following query: + * + * {{{ + * SELECT * FROM VALUES (2024, 1) table(c1, c2) PIVOT (sum(c1) FOR c2 IN (1 AS q1)); + * }}} + * + * The parsed plan would be: + * + * {{{ + * 'Project [*] + * +- 'Pivot 'c2, [1 AS q1#2], ['sum('c1)] + * +- SubqueryAlias table + * +- LocalRelation [c1#0, c2#1] + * }}} + * + * Here the `pivotColumn` is `c2`, the `pivotValues` is `[1 AS q1]`, and the `aggregates` is + * `sum(c1)`. 
+ * After the resolution, the resolved plan would be: + * + * {{{ + * Project [q1#4] + * +- Project [__pivot_sum(c1) AS `sum(c1)`#3[0] AS q1#4] + * +- Aggregate [pivotfirst(c2#1, sum(c1)#2, 1, 0, 0) AS __pivot_sum(c1) AS `sum(c1)`#3] + * +- Aggregate [c2#1], [c2#1, sum(c1#0) AS sum(c1)#2] + * +- LocalRelation [c1#0, c2#1] + * }}} + */ + override def resolve(pivot: Pivot): LogicalPlan = { + val resolvedPivot = { + val resolvedChild = operatorResolver.resolve(pivot.child) + val resolvedGroupByExprsOpt = pivot.groupByExprsOpt.map { expressions => + expressions.map { expression => + expressionResolver + .resolveExpressionTreeInOperator( + parentOperator = pivot, + unresolvedExpression = expression + ) + .asInstanceOf[NamedExpression] + } + } + val resolvedPivotColumn = expressionResolver.resolveExpressionTreeInOperator( + parentOperator = pivot, + unresolvedExpression = pivot.pivotColumn + ) + val resolvedPivotValues = pivot.pivotValues.map { expression => + expressionResolver.resolveExpressionTreeInOperator( + parentOperator = pivot, + unresolvedExpression = expression + ) + } + val resolvedAggregates = expressionResolver.resolvePivotAggregates(pivot) + + checkUnorderablePivotColError(resolvedPivotColumn) + + val transformedPivot = PivotTransformer( + child = resolvedChild, + pivotValues = resolvedPivotValues, + pivotColumn = resolvedPivotColumn, + groupByExpressionsOpt = resolvedGroupByExprsOpt, + aggregates = resolvedAggregates, + childOutput = scopes.current.output, + newAlias = (child, name) => autoGeneratedAliasProvider.newAlias(child, name) + ) + + handleResolvedOperator(transformedPivot) + } + + val newOutput = computeOutput(resolvedPivot) + + scopes.overwriteOutputAndExtendHiddenOutput(output = newOutput) + + resolvedPivot + } + + /** + * [[PivotTransformer]] may introduce new expressions (such as [[Cast]], [[If]] etc) that need + * to be resolved again. This method resolves those expressions in the given operator. 
+ */ + private def handleResolvedOperator(operator: LogicalPlan): LogicalPlan = { + operator match { + case project: Project => + val resolvedProjectList = project.projectList.map { expression => + expressionResolver + .resolveExpressionTreeInOperator( + parentOperator = project, + unresolvedExpression = expression + ) + .asInstanceOf[NamedExpression] + } + project.copy(projectList = resolvedProjectList) + case aggregate: Aggregate => + val resolvedAggregateExpressions = aggregate.aggregateExpressions.map { expression => + expressionResolver + .resolveExpressionTreeInOperator( + parentOperator = aggregate, + unresolvedExpression = expression + ) + .asInstanceOf[NamedExpression] + } + aggregate.copy(aggregateExpressions = resolvedAggregateExpressions) + } + } + + private def computeOutput(operator: LogicalPlan): Seq[Attribute] = { + operator match { + case project: Project => + project.projectList.map(namedExpression => namedExpression.toAttribute) + case aggregate: Aggregate => + aggregate.aggregateExpressions.map(namedExpression => namedExpression.toAttribute) + } + } + + private def checkUnorderablePivotColError(pivotColumn: Expression): Unit = { + if (!RowOrdering.isOrderable(pivotColumn.dataType)) { + throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RecursiveCteState.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RecursiveCteState.scala new file mode 100644 index 0000000000000..53ad8bdbca3c8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RecursiveCteState.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * Parameters specific to recursive CTEs. + * + * @param cteId The id of the recursive CTE. + * @param cteName The name of the recursive CTE. + * @param maxDepth The maximum depth of the recursive CTE. + */ +case class RecursiveCteParameters(cteId: Long, cteName: String, maxDepth: Option[Int]) + +/** + * Mutable state for recursive CTEs tracked during resolution. + * + * - `anchorOutput`: The output schema of the anchor branch, used for self-references. + * - `expectedUnionDepth`: The depth of the first UNION encountered in the recursive CTE. Used to + * ensure anchor registration and UnionLoop placement only occur at the correct depth, preventing + * nested UNIONs from interfering. + * - `columnNames`: Column names from UnresolvedSubqueryColumnAliases, if present. + * - `referencedRecursively`: Whether this CTE has been referenced from within itself.
+ */ +private[resolver] class RecursiveCteState { + var anchorOutput: Option[Seq[Attribute]] = None + var expectedUnionDepth: Option[Int] = None + var columnNames: Option[Seq[String]] = None + var referencedRecursively: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RepartitionByExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RepartitionByExpressionResolver.scala new file mode 100644 index 0000000000000..01636f35f2c52 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RepartitionByExpressionResolver.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.AnalysisErrorAt +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression} +import org.apache.spark.sql.types.{DataType, VariantType} + +/** + * Resolver class that resolves [[RepartitionByExpression]] operators. 
+ */ +class RepartitionByExpressionResolver(resolver: Resolver, expressionResolver: ExpressionResolver) + extends TreeNodeResolver[RepartitionByExpression, LogicalPlan] + with ResolvesNameByHiddenOutput { + + private val scopes = resolver.getNameScopes + private val operatorResolutionContextStack = resolver.getOperatorResolutionContextStack + + /** + * Resolves a [[RepartitionByExpression]] operator. This involves: + * 1. Resolving the child operator. + * 2. Resolving the partition expressions using the + * [[ExpressionResolver.resolveExpressionTreeInOperator]]. + * 3. In case the initial expressions list is non-empty, collect all the attributes referenced + * from the hidden output. This can happen in the following query: + * + * {{{ + * SELECT col1 FROM values(1, 2) WHERE col2 = 1 DISTRIBUTE BY col2; + * }}} + * + * Parsed plan looks like: + * + * {{{ + * 'RepartitionByExpression ['col2] + * +- 'Project ['col1] + * +- 'Filter ('col2 = 1) + * +- LocalRelation [col1#0, col2#1] + * }}} + * + * Here `col2` from the `DISTRIBUTE BY` would be resolved using the hidden output and thus + * the analyzed plan looks like: + * + * {{{ + * Project [col1#0] + * +- RepartitionByExpression [col2#1] + * +- Project [col1#0, col2#1] + * +- Filter (col2#1 = 1) + * +- LocalRelation [col1#0, col2#1] + * }}} + * 4. Inserting any missing attributes from the hidden output into the child operator. + * 5. Retaining the original output of the operator (see [[ResolvesNameByHiddenOutput]] scala + * doc for more info).
+ */ + override def resolve(repartitionByExpression: RepartitionByExpression): LogicalPlan = { + val resolvedChild = resolver.resolve(repartitionByExpression.child) + + scopes.current.markScopeForHiddenOutputResolution() + + val resolvedExpressions = repartitionByExpression.expressions.map { expression => + expressionResolver.resolveExpressionTreeInOperator( + unresolvedExpression = expression, + parentOperator = repartitionByExpression + ) + } + + val missingAttributes: Seq[Attribute] = + if (repartitionByExpression.partitionExpressions.isEmpty) { + Seq.empty + } else { + scopes.current.resolveMissingAttributesByHiddenOutput( + expressionResolver.getLastReferencedAttributes + ) + } + + val resolvedChildWithMissingAttributes = + insertMissingExpressions(resolvedChild, missingAttributes) + + val resolvedRepartitionByExpression = repartitionByExpression.copy( + partitionExpressions = resolvedExpressions, + child = resolvedChildWithMissingAttributes + ) + + validateRepartitionByVariant(resolvedRepartitionByExpression) + + if (!resolvedChildWithMissingAttributes.eq(resolvedChild)) { + retainOriginalOutput( + operator = resolvedRepartitionByExpression, + missingExpressions = missingAttributes, + scopes = scopes, + operatorResolutionContextStack = operatorResolutionContextStack + ) + } else { + resolvedRepartitionByExpression + } + } + + private def validateRepartitionByVariant( + repartitionByExpression: RepartitionByExpression): Unit = { + val variantExpressionInPartitionExpression = + repartitionByExpression.partitionExpressions.find(e => hasVariantType(e.dataType)) + + if (variantExpressionInPartitionExpression.isDefined) { + val variantExpr = variantExpressionInPartitionExpression.get + repartitionByExpression.failAnalysis( + errorClass = "UNSUPPORTED_FEATURE.PARTITION_BY_VARIANT", + messageParameters = + Map("expr" -> toSQLExpr(variantExpr), "dataType" -> toSQLType(variantExpr.dataType)) + ) + } + } + + private def hasVariantType(dt: DataType): Boolean = { + 
dt.existsRecursively(_.isInstanceOf[VariantType]) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunnerResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunnerResult.scala new file mode 100644 index 0000000000000..7c8305053d831 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunnerResult.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * The result of [[ResolverRunner]] operation. Can be one of the following: + * - [[ResolverRunnerResultResolvedPlan]]: The plan was properly resolved by the single-pass + * Analyzer. + * - [[ResolverRunnerResultPlanNotSupported]]: The plan is not supported and cannot be processed + * by the single-pass Analyzer. We would fall back to the fixed-point Analyzer for this plan. + * - [[ResolverRunnerResultUnrecoverableException]]: An unrecoverable exception occurred during + * the single-pass Analysis, which prevents to use the fixed-point Analyzer as a fallback. 
We + * should just propagate it to the user from the [[HybridAnalyzer]]. + */ +sealed trait ResolverRunnerResult + +case class ResolverRunnerResultResolvedPlan(plan: LogicalPlan) extends ResolverRunnerResult + +case class ResolverRunnerResultPlanNotSupported(reason: String) extends ResolverRunnerResult + +case class ResolverRunnerResultUnrecoverableException(exception: Throwable) + extends ResolverRunnerResult diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TableValuedFunctionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TableValuedFunctionResolver.scala new file mode 100644 index 0000000000000..913319df1e1e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TableValuedFunctionResolver.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.analysis.{FunctionResolution, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * Resolver class that resolves [[UnresolvedTableValuedFunction]] operators. 
+ */ +class TableValuedFunctionResolver( + resolver: Resolver, + expressionResolver: ExpressionResolver, + functionResolution: FunctionResolution) { + + /** + * Resolves [[UnresolvedTableValuedFunction]] operator by replacing it with an arbitrary + * [[LogicalPlan]]. + * + * The resolution happens in several steps: + * - Resolve all the function arguments using [[ExpressionResolver]]. + * - Resolve the function itself using [[FunctionResolution]] getting a [[LogicalPlan]]. + * - Recursively resolve the resulting [[LogicalPlan]]. + * + * Examples: + * - `range(...)` resolves to [[Range]] operator. + * {{{ + * SELECT * FROM range(10) + * + * == Parsed Logical Plan == + * 'Project [*] + * +- 'UnresolvedTableValuedFunction [range], [10], false + * + * == Analyzed Logical Plan == + * id: bigint + * Project [id#0] + * +- Range (0, 10, step=1) + * }}} + * + * - `explode(...)` resolves to [[Generate]] operator with [[OneRowRelation]]. + * {{{ + * SELECT * FROM explode(array(1, 2, 3)) + * + * == Parsed Logical Plan == + * 'Project [*] + * +- 'UnresolvedTableValuedFunction [explode], ['array(1, 2, 3)], false + * + * == Analyzed Logical Plan == + * col: int + * Project [col#0] + * +- Generate explode(array(1, 2, 3)), false, [col#0] + * +- OneRowRelation + * }}} + */ + def resolve(unresolvedTVF: UnresolvedTableValuedFunction): LogicalPlan = { + val resolvedArguments = unresolvedTVF.functionArgs.map( + expressionResolver.resolveExpressionTreeInOperator(_, unresolvedTVF) + ) + val unresolvedTVFWithResolvedArguments = + unresolvedTVF.copy(functionArgs = resolvedArguments) + val resolvedTVF = + functionResolution.resolveTableValuedFunction(unresolvedTVFWithResolvedArguments) + resolver.resolve(resolvedTVF) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnpivotResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnpivotResolver.scala new file mode 100644 index 0000000000000..62298e74d0cca --- 
/dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnpivotResolver.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.HashSet + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.{UnpivotTransformer, UnpivotTypeCoercion} +import org.apache.spark.sql.catalyst.expressions.{ + AttributeReference, + Expression, + ExprId, + NamedExpression +} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, LogicalPlan, Unpivot} +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * Resolver class that resolves [[Unpivot]] operators. + */ +class UnpivotResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver) + extends TreeNodeResolver[Unpivot, LogicalPlan] { + private val scopes: NameScopeStack = operatorResolver.getNameScopes + private val expressionIdAssigner = expressionResolver.getExpressionIdAssigner + + /** + * Resolves an [[Unpivot]] operator. Do that using the following steps: + * 1. Resolve the child operator. + * 2. 
Resolve the ids and values expressions based on their presence (see respective scala docs + * for more info). + * 3. Perform type coercion on the values expressions. + * 4. Validate the lengths of the values expressions. + * 5. Use the [[UnpivotTransformer]] to create the resolved operator (see its scala doc for more + * info). + * 6. Map the output attributes to new ones with unique expression ids. This is needed since the + * [[UnpivotTransformer]] introduces new attributes (in the [[Expand.output]]). + * 7. Overwrite the current scope's output with the resolved operator's output. + * + * For example in the given query: + * + * {{{ + * SELECT * FROM VALUES (2024, 15000, 20000) AS sales(year, q1, q2) + * UNPIVOT (revenue FOR quarter IN (q1 AS `Q1`, q2 AS `Q2`)); + * }}} + * + * The unresolved plan would be: + * + * {{{ + * 'Project [*] + * +- 'Filter isnotnull(coalesce('revenue)) + * +- 'Unpivot List(List('q1), List('q2)), List(Some(Q1), Some(Q2)), quarter, [revenue] + * +- SubqueryAlias sales + * +- LocalRelation [year#0, q1#1, q2#2] + * }}} + * + * Whereas the resolved plan would be: + * + * {{{ + * Project [year#0, quarter#4, revenue#5] + * +- Filter isnotnull(coalesce(revenue#5)) + * +- Expand [[year#0, Q1, q1#1], [year#0, Q2, q2#2], [year#0, quarter#4, revenue#5] + * +- SubqueryAlias sales + * +- LocalRelation [year#0, q1#1, q2#2] + * }}} + */ + override def resolve(unpivot: Unpivot): LogicalPlan = { + val resolvedChild = operatorResolver.resolve(unpivot.child) + + val (resolvedIds, resolvedValues) = (unpivot.ids, unpivot.values) match { + case (Some(ids), Some(values)) => handleIdsDefinedValuesDefined(ids, values, unpivot) + case (Some(ids), None) => handleIdsDefinedValuesUndefined(ids, unpivot) + case (None, Some(values)) => handleIdsUndefinedValuesDefined(values, unpivot) + case (None, None) => handleIdsUndefinedValuesUndefined() + } + + val partiallyResolvedUnpivot = + unpivot.copy(ids = Some(resolvedIds), values = Some(resolvedValues)) + + val 
typeCoercedUnpivot = UnpivotTypeCoercion(partiallyResolvedUnpivot) + + validateValuesTypeCoercioned(typeCoercedUnpivot) + + validateValuesLength(typeCoercedUnpivot.values.get, typeCoercedUnpivot) + + val resolvedOperator: Expand = UnpivotTransformer( + ids = typeCoercedUnpivot.ids.get, + values = typeCoercedUnpivot.values.get, + aliases = typeCoercedUnpivot.aliases, + variableColumnName = typeCoercedUnpivot.variableColumnName, + valueColumnNames = typeCoercedUnpivot.valueColumnNames, + child = resolvedChild + ) + + val mappedOutput = resolvedOperator.output.map { attribute => + expressionIdAssigner.mapExpression(attribute, allowUpdatesForAttributeReferences = true) + } + + val finalOperator = resolvedOperator.copy(output = mappedOutput) + + scopes.overwriteOutputAndExtendHiddenOutput(output = mappedOutput) + + finalOperator + } + + /** + * If both ids and values are defined, resolve them both. Below is an example in SQL: + * + * {{{ + * SELECT * FROM VALUES (2024, 15000, 20000) AS sales(year, q1, q2) + * UNPIVOT (revenue FOR quarter IN (q1 AS `Q1`, q2 AS `Q2`)); + * }}} + */ + private def handleIdsDefinedValuesDefined( + ids: Seq[Expression], + values: Seq[Seq[Expression]], + unpivot: Unpivot): (Seq[NamedExpression], Seq[Seq[NamedExpression]]) = { + val resolvedIds = resolveIds(ids, unpivot) + val resolvedValues = resolveValues(values, unpivot) + (resolvedIds, resolvedValues) + } + + /** + * If ids are defined but values are not, resolve the ids and validate them. Then, infer the + * values by taking all output attributes from the current scope that are not part of the ids. + * Finally, resolve the inferred values. 
This can happen in DFs: + * {{{ + * df.unpivot( + * ids = Array($"id"), + * values = Array.empty, + * variableColumnName = "var", + * valueColumnName = "val" + * ) + * }}} + */ + private def handleIdsDefinedValuesUndefined( + ids: Seq[Expression], + unpivot: Unpivot): (Seq[NamedExpression], Seq[Seq[NamedExpression]]) = { + val resolvedIds = resolveIds(ids, unpivot) + + validateIds(resolvedIds) + + val resolvedIdExprIds = new HashSet[ExprId](resolvedIds.size) + resolvedIds.foreach { id => + resolvedIdExprIds.add(id.exprId) + } + val values = scopes.current.output.filterNot { attr => + resolvedIdExprIds.contains(attr.exprId) + } + + val expandedValues = values.map(Seq(_)) + + (resolvedIds, expandedValues) + } + + /** + * If values are defined but ids are not, resolve the values and validate them. Then, infer the + * ids by taking all output attributes from the current scope that are not part of the values. + * Finally, resolve the inferred ids. This can happen in DFs: + * {{{ + * df.unpivot( + * ids = Array.empty, + * values = Array($"str1", $"str2"), + * variableColumnName = "var", + * valueColumnName = "val" + * ) + * }}} + */ + private def handleIdsUndefinedValuesDefined( + values: Seq[Seq[Expression]], + unpivot: Unpivot): (Seq[NamedExpression], Seq[Seq[NamedExpression]]) = { + val resolvedValues = resolveValues(values, unpivot) + + validateValues(resolvedValues) + + val flattenedValues = resolvedValues.flatten + val resolvedValueExprIds = new HashSet[ExprId](flattenedValues.size) + flattenedValues.foreach { value => + resolvedValueExprIds.add(value.exprId) + } + val ids = scopes.current.output.filterNot { attr => + resolvedValueExprIds.contains(attr.exprId) + } + + (ids, resolvedValues) + } + + private def handleIdsUndefinedValuesUndefined(): Nothing = { + throw SparkException.internalError("Both UNPIVOT ids and values cannot be None") + } + + private def resolveIds(ids: Seq[Expression], unpivot: Unpivot): Seq[NamedExpression] = { + 
expressionResolver.resolveUnpivotArguments(ids, unpivot) + } + + /** + * Resolves each value group in `values` using `expressionResolver.resolveUnpivotArguments`. + * For each value group: + * - If a single expression expanded to multiple (star expansion), each resolved element + * becomes its own value group. Example: + * + * {{{ + * df.select($"str1", $"str2").unpivot( + * Array.empty, + * Array($"*"), + * variableColumnName = "var", + * valueColumnName = "val" + * ) + * }}} + * + * Here we would expand the star into two resolved expressions (`str1` and `str2`) and then + * separate them into their own value groups: `Seq(Seq(str1), Seq(str2))`. + * - Otherwise, all resolved elements stay together in one value group. + * The results are flattened into a single sequence of value groups at the end. + */ + private def resolveValues( + values: Seq[Seq[Expression]], + unpivot: Unpivot): Seq[Seq[NamedExpression]] = { + values.flatMap { valueGroup => + val resolved = expressionResolver.resolveUnpivotArguments(valueGroup, unpivot) + if (valueGroup.length == 1 && resolved.length > 1) { + resolved.map(Seq(_)) + } else { + Seq(resolved) + } + } + } + + private def validateIds(ids: Seq[NamedExpression]): Unit = { + val hasNonAttributeExpression = ids.exists { + case _: AttributeReference => false + case _ => true + } + if (hasNonAttributeExpression) { + throw QueryCompilationErrors.unpivotRequiresAttributes("id", "value", ids) + } + } + + private def validateValues(values: Seq[Seq[NamedExpression]]): Unit = { + val hasNonAttributeExpression = values.exists { valueGroup => + valueGroup.exists { + case _: AttributeReference => false + case _ => true + } + } + if (hasNonAttributeExpression) { + throw QueryCompilationErrors.unpivotRequiresAttributes("value", "id", values.flatten) + } + } + + private def validateValuesTypeCoercioned(typeCoercedUnpivot: Unpivot): Unit = { + if (!typeCoercedUnpivot.valuesTypeCoercioned) { + throw 
QueryCompilationErrors.unpivotValueDataTypeMismatchError( + typeCoercedUnpivot.values.get + ) + } + } + + private def validateValuesLength(values: Seq[Seq[NamedExpression]], unpivot: Unpivot): Unit = { + if (values.exists(_.length != unpivot.valueColumnNames.length)) { + throw QueryCompilationErrors.unpivotValueSizeMismatchError(unpivot.valueColumnNames.length) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala new file mode 100644 index 0000000000000..d4c8688836510 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.Locale + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.{IntegerType, StringType} + +class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { + + private val idAttr = AttributeReference("id", IntegerType)() + private val nameAttr = AttributeReference("name", StringType)() + private val ageAttr = AttributeReference("age", IntegerType)() + + test("identical plans should return empty strings") { + val plan1 = LocalRelation(idAttr, nameAttr) + val plan2 = LocalRelation(idAttr, nameAttr) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + assert(result1 == "") + assert(result2 == "") + } + + test("different plans with default context size (2 lines)") { + // Create larger plans with multiple operations + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 10) + .where(nameAttr === "Alice") + .where(ageAttr < 50) + .select(idAttr, nameAttr) + .limit(100) + + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 20) // Different condition - first mismatch + .where(nameAttr === "Alice") + .where(ageAttr < 50) + .select(idAttr, nameAttr) + .limit(100) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + // Results should be truncated (shorter than full plans since mismatch is in the middle) + assert(result1.length < plan1.toString.length) + assert(result2.length < plan2.toString.length) + + // Should contain the different filter conditions + assert(result1.contains("10")) + assert(result2.contains("20")) + + // Should contain the mismatch line (Filter with id comparison) + 
assert(result1.contains("Filter")) + assert(result2.contains("Filter")) + + // Should contain the second filter (Alice) as context after the mismatch + assert(result1.contains("Alice")) + assert(result2.contains("Alice")) + + // Should contain context (age filter is also nearby) + assert(result1.contains("age") || result1.contains("50")) + assert(result2.contains("age") || result2.contains("50")) + + // Should NOT contain operations too far away (limit is at the top) + val lines1 = result1.split("\n").filter(_.nonEmpty).length + val lines2 = result2.split("\n").filter(_.nonEmpty).length + assert( + lines1 <= 7, + s"Expected at most 7 lines (2 before + mismatch + 2 after + margins), got $lines1" + ) + assert(lines2 <= 7, s"Expected at most 7 lines, got $lines2") + } + + test("plans differing at first line") { + // Larger plans where the top-level operation differs + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 5) + .where(nameAttr === "Bob") + .select(idAttr, nameAttr) // Different projection + .limit(50) + + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 5) + .where(nameAttr === "Bob") + .select(idAttr, ageAttr) // Different projection - includes age instead of name + .limit(50) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + // Should show the first line (GlobalLimit) and context after + assert(result1.contains("GlobalLimit")) + assert(result2.contains("GlobalLimit")) + + // Should show the mismatch in the Project + assert(result1.contains("Project")) + assert(result2.contains("Project")) + assert(result1.contains("name")) + assert(result2.contains("age")) + + // Should show some context after (first Filter) + assert(result1.contains("Filter")) + assert(result2.contains("Filter")) + + // Should NOT show the bottom LocalRelation (too far) + assert(!result1.contains("LocalRelation")) + assert(!result2.contains("LocalRelation")) + } + + test("plans differing at last line") { + // Large plans 
where only the bottom LocalRelation differs + val plan1 = LocalRelation(idAttr, nameAttr) // Different attributes + .where(idAttr > 5) + .where(nameAttr === "Charlie") + .where(idAttr < 100) + .select(idAttr, nameAttr) + .limit(10) + + val plan2 = LocalRelation(idAttr, ageAttr) // Different attributes - uses age instead of name + .where(idAttr > 5) + .where(nameAttr === "Charlie") + .where(idAttr < 100) + .select(idAttr, nameAttr) + .limit(10) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + // The plans are the same except LocalRelation has different attributes at the bottom + // However, the FIRST line that differs will be much earlier in the tree + // because the filter references different attributes + + // Should show some Filter operations (the first mismatch will be in one of them) + assert(result1.contains("Filter")) + assert(result2.contains("Filter")) + + // Both should have truncated output + val lines1 = result1.split("\n").filter(_.nonEmpty).length + val lines2 = result2.split("\n").filter(_.nonEmpty).length + assert(lines1 >= 1 && lines1 <= 7, s"Expected 1-7 lines, got $lines1") + assert(lines2 >= 1 && lines2 <= 7, s"Expected 1-7 lines, got $lines2") + + // At least one should reference the different attributes or values + val combined1 = result1.toLowerCase(Locale.ROOT) + val combined2 = result2.toLowerCase(Locale.ROOT) + assert(combined1.contains("name") || combined1.contains("age") || combined1.contains("id")) + assert(combined2.contains("name") || combined2.contains("age") || combined2.contains("id")) + } + + test("custom context size - 0 lines (only mismatch line)") { + // Larger plan with mismatch in the middle + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 10) // First mismatch + .where(nameAttr === "David") + .where(ageAttr < 60) + .select(idAttr, nameAttr) + .limit(25) + + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 99) // Different value + .where(nameAttr === 
"David") + .where(ageAttr < 60) + .select(idAttr, nameAttr) + .limit(25) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 0) + + // With 0 context lines, should only show the mismatched line + val lines1 = result1.split("\n").filter(_.nonEmpty) + val lines2 = result2.split("\n").filter(_.nonEmpty) + + // Should have exactly 1 line (the mismatch) + assert(lines1.length == 1, s"Expected 1 line, got ${lines1.length}: ${lines1.mkString("; ")}") + assert(lines2.length == 1, s"Expected 1 line, got ${lines2.length}: ${lines2.mkString("; ")}") + + // Should contain the filter with the different values + assert(result1.contains("Filter") && result1.contains("10")) + assert(result2.contains("Filter") && result2.contains("99")) + + // Should NOT contain other parts of the plan + assert(!result1.contains("David")) + assert(!result2.contains("David")) + assert(!result1.contains("GlobalLimit")) + assert(!result2.contains("GlobalLimit")) + } + + test("custom context size - 5 lines") { + // Very large plan to demonstrate larger context window + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr =!= "") + .where(ageAttr > 0) + .where(ageAttr < 120) + .where(idAttr < 1000000) + .where(nameAttr === "Eve") // Mismatch is here (7th line from top) + .where(ageAttr > 18) + .where(idAttr % 2 === 0) + .select(idAttr, nameAttr, ageAttr) + .limit(500) + + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr =!= "") + .where(ageAttr > 0) + .where(ageAttr < 120) + .where(idAttr < 1000000) + .where(nameAttr === "Frank") // Different name - mismatch + .where(ageAttr > 18) + .where(idAttr % 2 === 0) + .select(idAttr, nameAttr, ageAttr) + .limit(500) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 5) + + // With 5 context lines, should show 5 before, mismatch, and 5 after (11 total) + val lines1 = result1.split("\n").filter(_.nonEmpty) + val lines2 = 
result2.split("\n").filter(_.nonEmpty) + + assert(lines1.length >= 10, s"Expected at least 10 lines, got ${lines1.length}") + assert(lines2.length >= 10, s"Expected at least 10 lines, got ${lines2.length}") + + // Should contain the mismatch + assert(result1.contains("Eve")) + assert(result2.contains("Frank")) + + // Should contain 5 lines of context before (includes filters with age and id checks) + assert(result1.contains("120")) // age < 120 filter + assert(result2.contains("120")) + + // Should contain 5 lines of context after (includes age > 18 and modulo filters) + assert(result1.contains("18")) // age > 18 filter + assert(result2.contains("18")) + assert(result1.contains("% 2")) // modulo check + assert(result2.contains("% 2")) + + // Should still be truncated (not showing LocalRelation at bottom or GlobalLimit at top) + assert(!result1.contains("LocalRelation")) + assert(!result2.contains("LocalRelation")) + } + + test("plans with different number of lines") { + // Plan 1 is short (3 operations) + val plan1 = LocalRelation(idAttr) + .where(idAttr > 100) + .select(idAttr) + + // Plan 2 is much longer (8 operations) + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr === "George") + .where(ageAttr > 21) + .where(idAttr < 1000) + .where(nameAttr =!= "") + .select(idAttr, nameAttr, ageAttr) + .limit(100) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + // The mismatch is at the top (Project vs GlobalLimit), so both should show their top portions + val lines1 = result1.split("\n").filter(_.nonEmpty) + val lines2 = result2.split("\n").filter(_.nonEmpty) + + assert(lines1.length >= 1, s"Plan1 should have at least 1 line, got ${lines1.length}") + assert(lines2.length >= 1, s"Plan2 should have at least 1 line, got ${lines2.length}") + + // Both should show their top-level operations + assert(result1.contains("Project")) + assert(result2.contains("GlobalLimit") || result2.contains("LocalLimit")) + + // 
Should show some context + assert(result1.contains("Filter") && result1.contains("100")) + assert(result2.contains("Project") || result2.contains("Filter")) + } + + test("plans where second plan is shorter") { + // Plan 1 is longer (7 operations) + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr === "Helen") + .where(ageAttr < 80) + .where(idAttr < 500) + .select(idAttr, nameAttr) + .limit(50) + + // Plan 2 is shorter (3 operations) + val plan2 = LocalRelation(idAttr) + .where(idAttr > 5) + .select(idAttr) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + val lines1 = result1.split("\n").filter(_.nonEmpty) + val lines2 = result2.split("\n").filter(_.nonEmpty) + + // Both should have content + assert(lines1.length >= 1, s"Plan1 should have at least 1 line") + assert(lines2.length >= 1, s"Plan2 should have at least 1 line") + + // Should show the top-level difference + assert(result1.contains("GlobalLimit") || result1.contains("LocalLimit")) + assert(result2.contains("Project")) + + // Plan1 should show some of its filters as context + assert(result1.contains("Project") || result1.contains("Filter")) + + // Plan2 should show its simpler structure + assert(result2.contains("Filter") && result2.contains("5")) + } + + test("complex nested plans") { + // Create larger subqueries with multiple operations + val subquery1 = LocalRelation(idAttr, ageAttr) + .where(idAttr > 5) + .where(ageAttr > 18) + .select(idAttr, ageAttr) + + val subquery2 = LocalRelation(idAttr, ageAttr) + .where(idAttr > 10) // Different threshold - mismatch + .where(ageAttr > 18) + .select(idAttr, ageAttr) + + val plan1 = LocalRelation(nameAttr) + .where(nameAttr =!= "") + .join(subquery1) + .where(idAttr < 1000) + .select(nameAttr, idAttr, ageAttr) + .limit(200) + + val plan2 = LocalRelation(nameAttr) + .where(nameAttr =!= "") + .join(subquery2) + .where(idAttr < 1000) + .select(nameAttr, idAttr, ageAttr) + .limit(200) + + val (result1, 
result2) = LogicalPlanDifference(plan1, plan2, 2) + + // Should show the mismatch which is in the subquery filter + assert(result1.contains("5")) + assert(result2.contains("10")) + + // Should show filter operations around the mismatch + assert(result1.contains("Filter")) + assert(result2.contains("Filter")) + + // Verify the outputs are different + assert(result1 != result2, "Plans should produce different output strings") + + // Should be truncated (not showing all operations) + val lines1 = result1.split("\n").filter(_.nonEmpty).length + val lines2 = result2.split("\n").filter(_.nonEmpty).length + assert(lines1 <= 10, s"Expected truncated output with at most 10 lines, got $lines1") + assert(lines2 <= 10, s"Expected truncated output with at most 10 lines, got $lines2") + assert(lines1 >= 1, "Should have at least the mismatch line") + assert(lines2 >= 1, "Should have at least the mismatch line") + } + + test("plans with special characters in string representation") { + val plan1 = LocalRelation(nameAttr) + .where(nameAttr === "test\nwith\nnewlines") + + val plan2 = LocalRelation(nameAttr) + .where(nameAttr === "different\nvalue") + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + assert(result1.nonEmpty) + assert(result2.nonEmpty) + } + + test("empty plans") { + val plan1 = LocalRelation(Seq.empty) + val plan2 = LocalRelation(Seq.empty) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) + + // Empty plans should be equal, so return empty strings + assert(result1 == "") + assert(result2 == "") + } + + test("very large context size should not exceed plan length") { + val plan1 = LocalRelation(idAttr) + .where(idAttr > 10) + .select(idAttr) + + val plan2 = LocalRelation(idAttr) + .where(idAttr > 20) + .select(idAttr) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 1000) + + // With very large context, should show entire plans + val lines1 = result1.split("\n").filter(_.nonEmpty).length + val lines2 = 
result2.split("\n").filter(_.nonEmpty).length + val planLines1 = plan1.toString.split("\n").filter(_.nonEmpty).length + val planLines2 = plan2.toString.split("\n").filter(_.nonEmpty).length + + assert(lines1 <= planLines1) + assert(lines2 <= planLines2) + } + + test("context size validation - negative value") { + // The config has a check that should prevent negative values + // This test verifies that the validation is in place + intercept[IllegalArgumentException] { + withSQLConf( + "spark.sql.analyzer.singlePassResolver.logicalPlanDiffContextSize" -> "-1" + ) { + // This should fail due to the checkValue in the config + } + } + } + + test("mismatch at boundary with minimal context") { + // Larger plan with mismatch in the middle + val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr =!= "") + .where(ageAttr > 10) // Mismatch here + .where(idAttr < 10000) + .select(idAttr, nameAttr) + .limit(75) + + val plan2 = LocalRelation(idAttr, nameAttr, ageAttr) + .where(idAttr > 0) + .where(nameAttr =!= "") + .where(ageAttr > 25) // Different value + .where(idAttr < 10000) + .select(idAttr, nameAttr) + .limit(75) + + val (result1, result2) = LogicalPlanDifference(plan1, plan2, 1) + + // With 1 context line, should show 1 line before, the mismatch, and 1 line after (3 total) + val lines1 = result1.split("\n").filter(_.nonEmpty) + val lines2 = result2.split("\n").filter(_.nonEmpty) + + assert( + lines1.length <= 3, + s"Expected at most 3 lines, got ${lines1.length}: ${lines1.mkString("; ")}" + ) + assert( + lines2.length <= 3, + s"Expected at most 3 lines, got ${lines2.length}: ${lines2.mkString("; ")}" + ) + + // Should contain the mismatch (age filter) + assert(result1.contains("Filter") && result1.contains("10")) + assert(result2.contains("Filter") && result2.contains("25")) + + // Should contain 1 line of context before (name filter) + assert(result1.contains("name") || result1.contains("\"\"")) + assert(result2.contains("name") || 
result2.contains("\"\"")) + + // Should contain 1 line of context after (id filter) + assert(result1.contains("10000")) + assert(result2.contains("10000")) + + // Should NOT contain operations too far away + assert(!result1.contains("GlobalLimit")) + assert(!result2.contains("GlobalLimit")) + assert(!result1.contains("LocalRelation")) + assert(!result2.contains("LocalRelation")) + } +} From d16c24634278fe9042d41066b57c697e1626d2be Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Wed, 18 Mar 2026 13:31:05 +0000 Subject: [PATCH 2/3] Remove negative config validation test from LogicalPlanDifferenceSuite The test expected IllegalArgumentException for a negative config value, but the config key doesn't have that validation in OSS (it was a DatabricksSQLConf check). Co-authored-by: Isaac --- .../resolver/LogicalPlanDifferenceSuite.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala index d4c8688836510..3b513c4b564b9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala @@ -418,18 +418,6 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(lines2 <= planLines2) } - test("context size validation - negative value") { - // The config has a check that should prevent negative values - // This test verifies that the validation is in place - intercept[IllegalArgumentException] { - withSQLConf( - "spark.sql.analyzer.singlePassResolver.logicalPlanDiffContextSize" -> "-1" - ) { - // This should fail due to the checkValue in the config - } - } - } - test("mismatch at boundary with minimal context") { // Larger plan with mismatch in 
the middle val plan1 = LocalRelation(idAttr, nameAttr, ageAttr) From cb41322a9533acbfbd31567cc44b2f35cf0a4c73 Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Wed, 18 Mar 2026 16:49:33 +0000 Subject: [PATCH 3/3] Address PR review comments - GroupingAnalyticsResolver: simplify by negating condition and throwing early - LogicalPlanDifference: extract appendLine helper to dedup buffer iteration - NameParameterizedQueryResolver: use multi-line string for error message - LogicalPlanDifferenceSuite: extract repeated string literals into helper vals Co-authored-by: Isaac --- .../resolver/GroupingAnalyticsResolver.scala | 9 ++- .../resolver/LogicalPlanDifference.scala | 19 ++--- .../NameParameterizedQueryResolver.scala | 5 +- .../resolver/LogicalPlanDifferenceSuite.scala | 74 ++++++++++--------- 4 files changed, 56 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala index 6ba98677ab2c8..af92e2cb453ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala @@ -119,13 +119,14 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi def handleSortOrderExpressionsWithGroupingAnalytics( orderExpressions: Seq[SortOrder], sort: Sort): Seq[SortOrder] = { - val groupingExpressions = if (scopes.current.baseAggregate.isDefined) { + if (!scopes.current.baseAggregate.isDefined) { + throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() + } + + val groupingExpressions = GroupingAnalyticsTransformer.collectGroupingExpressions( scopes.current.baseAggregate.get ) - } else { - throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() 
- } val groupingId = VirtualColumn.groupingIdAttribute val resolvedGroupingId = expressionResolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala index c08977f9efb53..7183bdb6469d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifference.scala @@ -75,30 +75,27 @@ object LogicalPlanDifference { ) } + def appendLine(result: StringBuilder, planString: String, position: LinePosition): Unit = { + result.append(planString, position.start, position.end).append('\n') + } + if (!mismatchFound) { ("", "") } else { val lhsResult = new StringBuilder() val rhsResult = new StringBuilder() - lhsBuffer.foreach { position => - lhsResult.append(lhsPlanString, position.start, position.end).append('\n') - } - rhsBuffer.foreach { position => - rhsResult.append(rhsPlanString, position.start, position.end).append('\n') - } + lhsBuffer.foreach(appendLine(lhsResult, lhsPlanString, _)) + rhsBuffer.foreach(appendLine(rhsResult, rhsPlanString, _)) var linesAfter = 0 while (linesAfter < contextSize && (lhsIterator.hasNext || rhsIterator.hasNext)) { if (lhsIterator.hasNext) { - val position = lhsIterator.next() - lhsResult.append(lhsPlanString, position.start, position.end).append('\n') + appendLine(lhsResult, lhsPlanString, lhsIterator.next()) } if (rhsIterator.hasNext) { - val position = rhsIterator.next() - rhsResult.append(rhsPlanString, position.start, position.end).append('\n') + appendLine(rhsResult, rhsPlanString, rhsIterator.next()) } - linesAfter += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala index 861d0aa744a18..5e2b3bd7873b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameParameterizedQueryResolver.scala @@ -158,8 +158,9 @@ class NameParameterizedQueryResolver( parameterValues: Seq[Expression]): Unit = { if (parameterNames.length != parameterValues.length) { throw SparkException.internalError( - s"The number of argument names ${parameterNames.length} " + - s"must be equal to the number of argument values ${parameterValues.length}." + s"""The number of argument names ${parameterNames.length} + |must be equal to the number of argument values ${parameterValues.length}.""" + .stripMargin ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala index 3b513c4b564b9..36e739969b40f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LogicalPlanDifferenceSuite.scala @@ -33,6 +33,12 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { private val nameAttr = AttributeReference("name", StringType)() private val ageAttr = AttributeReference("age", IntegerType)() + private val GLOBAL_LIMIT = "GlobalLimit" + private val LOCAL_LIMIT = "LocalLimit" + private val LOCAL_RELATION = "LocalRelation" + private val FILTER = "Filter" + private val PROJECT = "Project" + test("identical plans should return empty strings") { val plan1 = LocalRelation(idAttr, nameAttr) val plan2 = LocalRelation(idAttr, nameAttr) @@ -70,8 +76,8 @@ class 
LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(result2.contains("20")) // Should contain the mismatch line (Filter with id comparison) - assert(result1.contains("Filter")) - assert(result2.contains("Filter")) + assert(result1.contains(FILTER)) + assert(result2.contains(FILTER)) // Should contain the second filter (Alice) as context after the mismatch assert(result1.contains("Alice")) @@ -108,22 +114,22 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { val (result1, result2) = LogicalPlanDifference(plan1, plan2, 2) // Should show the first line (GlobalLimit) and context after - assert(result1.contains("GlobalLimit")) - assert(result2.contains("GlobalLimit")) + assert(result1.contains(GLOBAL_LIMIT)) + assert(result2.contains(GLOBAL_LIMIT)) // Should show the mismatch in the Project - assert(result1.contains("Project")) - assert(result2.contains("Project")) + assert(result1.contains(PROJECT)) + assert(result2.contains(PROJECT)) assert(result1.contains("name")) assert(result2.contains("age")) // Should show some context after (first Filter) - assert(result1.contains("Filter")) - assert(result2.contains("Filter")) + assert(result1.contains(FILTER)) + assert(result2.contains(FILTER)) // Should NOT show the bottom LocalRelation (too far) - assert(!result1.contains("LocalRelation")) - assert(!result2.contains("LocalRelation")) + assert(!result1.contains(LOCAL_RELATION)) + assert(!result2.contains(LOCAL_RELATION)) } test("plans differing at last line") { @@ -149,8 +155,8 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { // because the filter references different attributes // Should show some Filter operations (the first mismatch will be in one of them) - assert(result1.contains("Filter")) - assert(result2.contains("Filter")) + assert(result1.contains(FILTER)) + assert(result2.contains(FILTER)) // Both should have truncated output val lines1 = 
result1.split("\n").filter(_.nonEmpty).length @@ -192,14 +198,14 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(lines2.length == 1, s"Expected 1 line, got ${lines2.length}: ${lines2.mkString("; ")}") // Should contain the filter with the different values - assert(result1.contains("Filter") && result1.contains("10")) - assert(result2.contains("Filter") && result2.contains("99")) + assert(result1.contains(FILTER) && result1.contains("10")) + assert(result2.contains(FILTER) && result2.contains("99")) // Should NOT contain other parts of the plan assert(!result1.contains("David")) assert(!result2.contains("David")) - assert(!result1.contains("GlobalLimit")) - assert(!result2.contains("GlobalLimit")) + assert(!result1.contains(GLOBAL_LIMIT)) + assert(!result2.contains(GLOBAL_LIMIT)) } test("custom context size - 5 lines") { @@ -252,8 +258,8 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(result2.contains("% 2")) // Should still be truncated (not showing LocalRelation at bottom or GlobalLimit at top) - assert(!result1.contains("LocalRelation")) - assert(!result2.contains("LocalRelation")) + assert(!result1.contains(LOCAL_RELATION)) + assert(!result2.contains(LOCAL_RELATION)) } test("plans with different number of lines") { @@ -282,12 +288,12 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(lines2.length >= 1, s"Plan2 should have at least 1 line, got ${lines2.length}") // Both should show their top-level operations - assert(result1.contains("Project")) - assert(result2.contains("GlobalLimit") || result2.contains("LocalLimit")) + assert(result1.contains(PROJECT)) + assert(result2.contains(GLOBAL_LIMIT) || result2.contains(LOCAL_LIMIT)) // Should show some context - assert(result1.contains("Filter") && result1.contains("100")) - assert(result2.contains("Project") || result2.contains("Filter")) + assert(result1.contains(FILTER) && result1.contains("100")) + 
assert(result2.contains(PROJECT) || result2.contains(FILTER)) } test("plans where second plan is shorter") { @@ -315,14 +321,14 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(lines2.length >= 1, s"Plan2 should have at least 1 line") // Should show the top-level difference - assert(result1.contains("GlobalLimit") || result1.contains("LocalLimit")) - assert(result2.contains("Project")) + assert(result1.contains(GLOBAL_LIMIT) || result1.contains(LOCAL_LIMIT)) + assert(result2.contains(PROJECT)) // Plan1 should show some of its filters as context - assert(result1.contains("Project") || result1.contains("Filter")) + assert(result1.contains(PROJECT) || result1.contains(FILTER)) // Plan2 should show its simpler structure - assert(result2.contains("Filter") && result2.contains("5")) + assert(result2.contains(FILTER) && result2.contains("5")) } test("complex nested plans") { @@ -358,8 +364,8 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(result2.contains("10")) // Should show filter operations around the mismatch - assert(result1.contains("Filter")) - assert(result2.contains("Filter")) + assert(result1.contains(FILTER)) + assert(result2.contains(FILTER)) // Verify the outputs are different assert(result1 != result2, "Plans should produce different output strings") @@ -452,8 +458,8 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { ) // Should contain the mismatch (age filter) - assert(result1.contains("Filter") && result1.contains("10")) - assert(result2.contains("Filter") && result2.contains("25")) + assert(result1.contains(FILTER) && result1.contains("10")) + assert(result2.contains(FILTER) && result2.contains("25")) // Should contain 1 line of context before (name filter) assert(result1.contains("name") || result1.contains("\"\"")) @@ -464,9 +470,9 @@ class LogicalPlanDifferenceSuite extends SparkFunSuite with SQLConfHelper { assert(result2.contains("10000")) // 
Should NOT contain operations too far away - assert(!result1.contains("GlobalLimit")) - assert(!result2.contains("GlobalLimit")) - assert(!result1.contains("LocalRelation")) - assert(!result2.contains("LocalRelation")) + assert(!result1.contains(GLOBAL_LIMIT)) + assert(!result2.contains(GLOBAL_LIMIT)) + assert(!result1.contains(LOCAL_RELATION)) + assert(!result2.contains(LOCAL_RELATION)) } }