diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala
index 09ba36f99a270..a3798955ea546 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDDLCommandStringTypes.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, DefaultStringProducingExpression, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns, AlterColumnSpec, AlterTableCommand, AlterViewAs, ColumnDefinition, CreateTable, CreateView, LogicalPlan, QualifiedColType, ReplaceColumns, V2CreateTablePlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.TableCatalog
@@ -100,11 +100,13 @@ object ResolveDDLCommandStringTypes extends Rule[LogicalPlan] {
    * new type instead of the default string type.
    */
   private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
-    plan resolveExpressionsUp { expression =>
+    val transformedPlan = plan resolveExpressionsUp { expression =>
       transformExpression
         .andThen(_.apply(newType))
         .applyOrElse(expression, identity[Expression])
     }
+
+    castDefaultStringExpressions(transformedPlan, newType)
   }
 
   /**
@@ -121,6 +123,37 @@ object ResolveDDLCommandStringTypes extends Rule[LogicalPlan] {
       newType => Literal(value, replaceDefaultStringType(dt, newType))
   }
 
+  /**
+   * Casts every [[DefaultStringProducingExpression]] in the plan to the `newType`.
+   *
+   * Returns `plan` untouched when `newType` equals [[StringType]]. An expression that is
+   * already the direct child of a [[Cast]] to `newType` is left under that cast instead of
+   * being wrapped again.
+   */
+  private def castDefaultStringExpressions(plan: LogicalPlan, newType: StringType): LogicalPlan = {
+    // Recursively rebuilds `ex`, wrapping default-string-producing nodes in a cast to `newType`.
+    def inner(ex: Expression): Expression = ex match {
+      // Skip if we already added a cast in the previous pass.
+      case cast @ Cast(e: DefaultStringProducingExpression, dt, _, _) if newType == dt =>
+        cast.copy(child = e.withNewChildren(e.children.map(inner)))
+
+      // Add cast on top of [[DefaultStringProducingExpression]].
+      case e: DefaultStringProducingExpression =>
+        Cast(e.withNewChildren(e.children.map(inner)), newType)
+
+      case other =>
+        other.withNewChildren(other.children.map(inner))
+    }
+
+    if (newType == StringType) {
+      plan
+    } else {
+      plan resolveOperators { operator =>
+        operator.mapExpressions(inner)
+      }
+    }
+  }
+
   private def hasDefaultStringType(dataType: DataType): Boolean =
     dataType.existsRecursively(isDefaultStringType)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
index 9e1968022c744..9679ecb554382 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
@@ -25,6 +25,14 @@ import org.apache.spark.sql.types.StringType
 
 abstract class DefaultCollationTestSuite extends QueryTest with SharedSparkSession {
 
+  /** SQL expressions whose result columns the view tests check for the view's collation. */
+  val defaultStringProducingExpressions: Seq[String] = Seq(
+    "current_timezone()", "current_database()", "md5('Spark' collate unicode)",
+    "soundex('Spark' collate unicode)", "url_encode('https://spark.apache.org' collate unicode)",
+    "url_decode('https%3A%2F%2Fspark.apache.org')", "uuid()", "chr(65)", "collation('UNICODE')",
+    "version()", "space(5)", "randstr(5, 123)"
+  )
+
   def dataSource: String = "parquet"
   def testTable: String = "test_tbl"
   def testView: String = "test_view"
@@ -329,6 +337,59 @@ class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
       }
     }
   }
+
+  test("view has utf8 binary collation by default") {
+    withView(testView) {
+      sql(s"CREATE VIEW $testView AS SELECT current_database() AS db")
+      assertTableColumnCollation(testView, "db", "UTF8_BINARY")
+    }
+  }
+
+  test("default string producing expressions in view definition") {
+    val viewDefaultCollation = Seq(
+      "UTF8_BINARY", "UNICODE"
+    )
+
+    viewDefaultCollation.foreach { collation =>
+      withView(testView) {
+
+        val columns = defaultStringProducingExpressions.zipWithIndex.map {
+          case (expr, index) => s"$expr AS c${index + 1}"
+        }.mkString(", ")
+
+        sql(
+          s"""
+             |CREATE view $testView
+             |DEFAULT COLLATION $collation
+             |AS SELECT $columns
+             |""".stripMargin)
+
+        (1 to defaultStringProducingExpressions.length).foreach { index =>
+          assertTableColumnCollation(testView, s"c$index", collation)
+        }
+      }
+    }
+  }
+
+  test("default string producing expressions in view definition - nested in expr tree") {
+    withView(testView) {
+      sql(
+        s"""
+           |CREATE view $testView
+           |DEFAULT COLLATION UNICODE AS SELECT
+           |SUBSTRING(current_database(), 1, 1) AS c1,
+           |SUBSTRING(SUBSTRING(current_database(), 1, 2), 1, 1) AS c2,
+           |SUBSTRING(current_database()::STRING, 1, 1) AS c3,
+           |SUBSTRING(CAST(current_database() AS STRING COLLATE UTF8_BINARY), 1, 1) AS c4
+           |""".stripMargin)
+
+      assertTableColumnCollation(testView, "c1", "UNICODE")
+      assertTableColumnCollation(testView, "c2", "UNICODE")
+      assertTableColumnCollation(testView, "c3", "UNICODE")
+      // c4 is cast with an explicit collation, so it is expected to stay UTF8_BINARY.
+      assertTableColumnCollation(testView, "c4", "UTF8_BINARY")
+    }
+  }
 }
 
 class DefaultCollationTestSuiteV2 extends DefaultCollationTestSuite with DatasourceV2SQLBase {