
Commit 9e4205e

iodone authored and yanghua committed
[KYUUBI apache#4393] [Kyuubi apache#4332] Fix some bugs with Groupby and CacheTable
close apache#4332

### _Why are the changes needed?_

For the case where the table name has already been resolved and an `Expand` logical plan exists:

```
InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [a, b]
+- Aggregate [a#0], [a#0, ansi_cast((count(if ((gid#9 = 1)) spark_catalog.default.t2.`b`#10 else null) * count(if ((gid#9 = 2)) spark_catalog.default.t2.`c`#11 else null)) as string) AS b#8]
   +- Aggregate [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
      +- Expand [ArrayBuffer(a#0, b#1, null, 1), ArrayBuffer(a#0, null, c#2, 2)], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
         +- HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#0, b#1, c#2], Partition Cols: []]
```

For the case of `CACHE TABLE` with a `window` function:

```
InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, true, false, [a, b]
+- Project [a#98, b#99]
   +- InMemoryRelation [a#98, b#99, rank#100], StorageLevel(disk, memory, deserialized, 1 replicas)
      +- *(2) Filter (isnotnull(rank#4) AND (rank#4 = 1))
         +- Window [row_number() windowspecdefinition(a#9, b#10 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#4], [a#9], [b#10 ASC NULLS FIRST]
            +- *(1) Sort [a#9 ASC NULLS FIRST, b#10 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(a#9, 200), ENSURE_REQUIREMENTS, [id=apache#38]
                  +- Scan hive default.t2 [a#9, b#10], HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#9, b#10], Partition Cols: []]
```

### _How was this patch tested?_

- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes apache#4393 from iodone/kyuubi-4332.

Closes apache#4393

d2afdab [odone] fix cache table bug
443af79 [odone] fix some bugs with groupby

Authored-by: odone <[email protected]>
Signed-off-by: ulyssesyou <[email protected]>
1 parent e4be464 commit 9e4205e
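For reference, the snippet below is adapted from the test cases added in this commit and reproduces the two plan shapes above; it assumes the suite's `spark` session and the Hive test tables `t1`/`t2` that the tests create.

```scala
// GROUP BY with multiple count(distinct ...) aggregates: Spark rewrites the query with an
// Expand node, whose qualified synthetic attribute names previously broke lineage parsing.
spark.sql(
  "insert into table t1 select a, " +
    "concat_ws('/', collect_set(b)), " +
    "count(distinct(b)) * count(distinct(c)) " +
    "from t2 group by a")

// CACHE TABLE over a window function: reading the cached result goes through an
// InMemoryRelation whose cached plan contains a Window node.
spark.sql(
  "cache table c1 select * from (" +
    "select a, b, row_number() over (partition by a order by b asc) rank from t2) " +
    "where rank = 1")
spark.sql("insert overwrite table t1 select a, b from c1")
```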

2 files changed: +155 -3 lines


extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala (+36 -3)
```diff
@@ -25,8 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical._
```
```diff
@@ -128,7 +127,7 @@ trait LineageParser {
           exp.toAttribute,
           if (!containsCountAll(exp.child)) references
           else references + exp.toAttribute.withName(AGGREGATE_COUNT_COLUMN_IDENTIFIER))
-      case a: Attribute => a -> a.references
+      case a: Attribute => a -> AttributeSet(a)
     }
     ListMap(exps: _*)
   }
```
```diff
@@ -149,6 +148,9 @@ trait LineageParser {
             attr.withQualifier(attr.qualifier.init)
           case attr if attr.name.equalsIgnoreCase(AGGREGATE_COUNT_COLUMN_IDENTIFIER) =>
             attr.withQualifier(qualifier)
+          case attr if isNameWithQualifier(attr, qualifier) =>
+            val newName = attr.name.split('.').last.stripPrefix("`").stripSuffix("`")
+            attr.withName(newName).withQualifier(qualifier)
         })
       }
     } else {
```
```diff
@@ -160,6 +162,12 @@ trait LineageParser {
     }
   }

+  private def isNameWithQualifier(attr: Attribute, qualifier: Seq[String]): Boolean = {
+    val nameTokens = attr.name.split('.')
+    val namespace = nameTokens.init.mkString(".")
+    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
+  }
+
   private def mergeRelationColumnLineage(
       parentColumnsLineage: AttributeMap[AttributeSet],
       relationOutput: Seq[Attribute],
```
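A minimal standalone sketch of the string handling in the new `isNameWithQualifier` guard and its rewrite rule, applied to an attribute name such as spark_catalog.default.t2.`b` from the plan in the commit message; plain `String`s stand in for Catalyst `Attribute`s, and the object name is made up for illustration.

```scala
// Hypothetical standalone illustration of isNameWithQualifier plus the name rewrite above.
object QualifiedNameSketch extends App {
  val name = "spark_catalog.default.t2.`b`" // attribute name emitted under the Expand plan
  val qualifier = Seq("default", "t2")      // qualifier of the source relation

  val nameTokens = name.split('.')              // Array(spark_catalog, default, t2, `b`)
  val namespace = nameTokens.init.mkString(".") // "spark_catalog.default.t2"
  val matches =
    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString(".")) // true

  // When the guard matches, the rule keeps only the last token and strips the backticks,
  // so lineage reports the plain column name b under the default.t2 qualifier.
  val newName = nameTokens.last.stripPrefix("`").stripSuffix("`") // "b"
  println(s"$matches -> $newName")                                // prints: true -> b
}
```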
```diff
@@ -327,6 +335,31 @@ trait LineageParser {
         joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
         p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)

+      case p: Expand =>
+        val references =
+          p.projections.transpose.map(_.flatMap(x => x.references)).map(AttributeSet(_))
+
+        val childColumnsLineage = ListMap(p.output.zip(references): _*)
+        val nextColumnsLineage =
+          joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+
+      case p: Window =>
+        val windowColumnsLineage =
+          ListMap(p.windowExpressions.map(exp => (exp.toAttribute, exp.references)): _*)
+
+        val nextColumnsLineage = if (parentColumnsLineage.isEmpty) {
+          ListMap(p.child.output.map(attr => (attr, attr.references)): _*) ++ windowColumnsLineage
+        } else {
+          parentColumnsLineage.map {
+            case (k, _) if windowColumnsLineage.contains(k) =>
+              k -> windowColumnsLineage(k)
+            case (k, attrs) =>
+              k -> AttributeSet(attrs.flatten(attr =>
+                windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
+          }
+        }
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
       case p: Join =>
         p.joinType match {
           case LeftSemi | LeftAnti =>
```
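The transpose in the new `Expand` case can be seen with a small self-contained sketch: plain Scala collections stand in for Catalyst expressions, `None` marks literals (null and the gid values) that reference no input column, and the object name is made up for illustration. The projections come from the plan in the commit message.

```scala
// Sketch of the transpose trick used by the Expand case:
// Expand [ArrayBuffer(a, b, null, 1), ArrayBuffer(a, null, c, 2)]
// with output [a, spark_catalog.default.t2.`b`, spark_catalog.default.t2.`c`, gid].
object ExpandTransposeSketch extends App {
  // Each inner Seq is one projection; Some(name) stands for an attribute reference,
  // None for a literal that references nothing.
  val projections = Seq(
    Seq(Some("a"), Some("b"), None, None), // ArrayBuffer(a, b, null, 1)
    Seq(Some("a"), None, Some("c"), None)) // ArrayBuffer(a, null, c, 2)

  // transpose groups the i-th expression of every projection together, so the i-th
  // Expand output column maps to every input attribute it can originate from.
  val references = projections.transpose.map(_.flatten.toSet)

  println(references) // List(Set(a), Set(b), Set(c), Set())
}
```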

extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala (+119)
```diff
@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }

+  test("test group by") {
+    withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
+      spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from t2 group by a")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set("default.t2.b", "default.t2.c")))))
+
+      val ret1 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+
+      val ret2 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"count(distinct(b+c))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+    }
+  }
+
+  test("test grouping sets") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a,b,GROUPING__ID " +
+            s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set()))))
+    }
+  }
+
+  test("test catch table with window function") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
+
+      spark.sql(
+        s"cache table c1 select * from (" +
+          s"select a, b, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
+      assert(ret1 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.a", "default.t2.b")))))
+
+      spark.sql(
+        s"cache table c2 select * from (" +
+          s"select b, a, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
+      assert(ret2 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      spark.sql(
+        s"cache table c3 select * from (" +
+          s"select a as aa, b as bb, row_number() over (partition by a order by b asc ) rank" +
+          s" from t2) where rank=1")
+      val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
+      assert(ret3 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+    }
+  }
+
   private def exectractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)
```
