@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }
 
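+  // For a group-by query, each output column's lineage should cover every source
+  // column its expression reads: plain grouping keys map one-to-one, while
+  // aggregates such as concat_ws('/', collect_set(b)) or
+  // count(distinct(b)) * count(distinct(c)) fan out to all referenced columns.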
+  test("test group by") {
+    withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
+      spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a, " +
+            s"concat_ws('/', collect_set(b)), " +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from t2 group by a")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set("default.t2.b", "default.t2.c")))))
+
+      val ret1 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a, " +
+            s"concat_ws('/', collect_set(b)), " +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+
+      val ret2 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a, " +
+            s"count(distinct(b+c)), " +
+            s"count(distinct(b)) * count(distinct(c)) " +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+    }
+  }
+
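+  // GROUPING__ID is a synthetic column produced by grouping sets, so the target
+  // column it feeds is expected to carry an empty set of source columns.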
+  test("test grouping sets") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a,b,GROUPING__ID " +
+            s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set()))))
+    }
+  }
+
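+  // Lineage should be traced through a cached query (CACHE TABLE) back to the
+  // underlying table, including columns produced by a window function.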
+  test("test cache table with window function") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
+
+      spark.sql(
+        s"cache table c1 select * from ( " +
+          s"select a, b, row_number() over (partition by a order by b asc) rank from t2) " +
+          s"where rank=1")
+      val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
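+      // Selecting the window's rank column should pull in both the partition key
+      // (a) and the order key (b) as its source columns.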
+      val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
+      assert(ret1 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.a", "default.t2.b")))))
+
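+      // Swapping the projection order in the cached query must not confuse the
+      // mapping: a and b still resolve to their own source columns.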
+      spark.sql(
+        s"cache table c2 select * from ( " +
+          s"select b, a, row_number() over (partition by a order by b asc) rank from t2) " +
+          s"where rank=1")
+      val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
+      assert(ret2 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
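+      // Aliased columns (aa, bb) in the cached query should still resolve to the
+      // original t2 columns.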
+      spark.sql(
+        s"cache table c3 select * from ( " +
+          s"select a as aa, b as bb, row_number() over (partition by a order by b asc) rank " +
+          s"from t2) where rank=1")
+      val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
+      assert(ret3 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+    }
+  }
+
   private def exectractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)