@@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]:
438
438
os .remove (local_file_name )
439
439
temp_table_name = f'dataframe_{ temp_id .replace ("-" , "_" )} '
440
440
self ._create_temp_table (temp_table_name , desired_s3_folder )
441
- base_features = list (self ._base .columns )
442
- event_time_identifier_feature_dtype = self ._base [
443
- self ._event_time_identifier_feature_name
444
- ].dtypes
445
- self ._event_time_identifier_feature_type = (
446
- FeatureGroup .DTYPE_TO_FEATURE_DEFINITION_CLS_MAP .get (
447
- str (event_time_identifier_feature_dtype ), None
448
- )
449
- )
450
- query_string = self ._construct_query_string (
451
- FeatureGroupToBeMerged (
452
- base_features ,
453
- self ._included_feature_names if self ._included_feature_names else base_features ,
454
- self ._included_feature_names if self ._included_feature_names else base_features ,
455
- _DEFAULT_CATALOG ,
456
- _DEFAULT_DATABASE ,
457
- temp_table_name ,
458
- self ._record_identifier_feature_name ,
459
- FeatureDefinition (
460
- self ._event_time_identifier_feature_name ,
461
- self ._event_time_identifier_feature_type ,
462
- ),
463
- None ,
464
- TableType .DATA_FRAME ,
465
- )
466
- )
467
- query_result = self ._run_query (query_string , _DEFAULT_CATALOG , _DEFAULT_DATABASE )
441
+ query_result = self ._run_query (* self ._to_athena_query (temp_table_name = temp_table_name ))
468
442
# TODO: cleanup temp table, need more clarification, keep it for now
469
443
return query_result .get ("QueryExecution" , {}).get ("ResultConfiguration" , {}).get (
470
444
"OutputLocation" , None
471
445
), query_result .get ("QueryExecution" , {}).get ("Query" , None )
472
446
if isinstance (self ._base , FeatureGroup ):
473
- base_feature_group = construct_feature_group_to_be_merged (
474
- self ._base , self ._included_feature_names
475
- )
476
- self ._record_identifier_feature_name = base_feature_group .record_identifier_feature_name
477
- self ._event_time_identifier_feature_name = (
478
- base_feature_group .event_time_identifier_feature .feature_name
479
- )
480
- self ._event_time_identifier_feature_type = (
481
- base_feature_group .event_time_identifier_feature .feature_type
482
- )
483
- query_string = self ._construct_query_string (base_feature_group )
484
- query_result = self ._run_query (
485
- query_string ,
486
- base_feature_group .catalog ,
487
- base_feature_group .database ,
488
- )
447
+ query_result = self ._run_query (* self ._to_athena_query ())
489
448
return query_result .get ("QueryExecution" , {}).get ("ResultConfiguration" , {}).get (
490
449
"OutputLocation" , None
491
450
), query_result .get ("QueryExecution" , {}).get ("Query" , None )
@@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
1058
1017
raise RuntimeError (f"The dataframe type { dataframe_type } is not supported yet." )
1059
1018
return f"{ column } { self ._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP .get (str (dataframe_type ), None )} "
1060
1019
1020
def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
    """Internal method for constructing an Athena query.

    Besides building the query, this method updates instance state: it always
    sets ``self._event_time_identifier_feature_type``, and for a FeatureGroup base
    it also sets ``self._record_identifier_feature_name`` and
    ``self._event_time_identifier_feature_name`` from the base feature group.

    Args:
        temp_table_name (str): The temporary Athena table name of the base
            pandas.DataFrame. Required when the base is a pandas.DataFrame and
            not used when the base is a FeatureGroup. Defaults to None.

    Returns:
        The query string.
        The name of the catalog to be used in the query execution.
        The database to be used in the query execution.

    Raises:
        ValueError: temp_table_name is not provided while the base is a
            pandas.DataFrame, or the base is neither a pandas.DataFrame nor a
            FeatureGroup.
    """
    if isinstance(self._base, pd.DataFrame):
        if temp_table_name is None:
            raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
        base_features = list(self._base.columns)
        # Map the pandas dtype of the event-time column to a feature type; an
        # unmapped dtype yields None.
        event_time_identifier_feature_dtype = self._base[
            self._event_time_identifier_feature_name
        ].dtypes
        self._event_time_identifier_feature_type = (
            FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
                str(event_time_identifier_feature_dtype), None
            )
        )
        catalog = _DEFAULT_CATALOG
        database = _DEFAULT_DATABASE
        # With no explicit feature selection, fall back to every DataFrame column.
        included_features = (
            self._included_feature_names if self._included_feature_names else base_features
        )
        query_string = self._construct_query_string(
            FeatureGroupToBeMerged(
                base_features,
                included_features,
                included_features,
                catalog,
                database,
                temp_table_name,
                self._record_identifier_feature_name,
                FeatureDefinition(
                    self._event_time_identifier_feature_name,
                    self._event_time_identifier_feature_type,
                ),
                None,
                TableType.DATA_FRAME,
            )
        )
    elif isinstance(self._base, FeatureGroup):
        base_feature_group = construct_feature_group_to_be_merged(
            self._base, self._included_feature_names
        )
        # Adopt the identifiers resolved for the base feature group.
        self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
        self._event_time_identifier_feature_name = (
            base_feature_group.event_time_identifier_feature.feature_name
        )
        self._event_time_identifier_feature_type = (
            base_feature_group.event_time_identifier_feature.feature_type
        )
        catalog = base_feature_group.catalog
        database = base_feature_group.database
        query_string = self._construct_query_string(base_feature_group)
    else:
        # Without this branch the return below would fail with an opaque
        # UnboundLocalError for an unsupported base.
        raise ValueError(
            f"Unsupported base type {type(self._base).__name__}; "
            "the base must be a pandas.DataFrame or a FeatureGroup."
        )
    return query_string, catalog, database
1080
+
1061
1081
def _run_query (self , query_string : str , catalog : str , database : str ) -> Dict [str , Any ]:
1062
1082
"""Internal method for execute Athena query, wait for query finish and get query result.
1063
1083
0 commit comments