Skip to content

Commit 0f5da99

Browse files
committed
Dataset builder: add _to_athena_query method
1 parent 342fbbc commit 0f5da99

File tree

1 file changed

+63
-43
lines changed

1 file changed

+63
-43
lines changed

src/sagemaker/feature_store/dataset_builder.py

+63-43
Original file line numberDiff line numberDiff line change
@@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]:
438438
os.remove(local_file_name)
439439
temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
440440
self._create_temp_table(temp_table_name, desired_s3_folder)
441-
base_features = list(self._base.columns)
442-
event_time_identifier_feature_dtype = self._base[
443-
self._event_time_identifier_feature_name
444-
].dtypes
445-
self._event_time_identifier_feature_type = (
446-
FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
447-
str(event_time_identifier_feature_dtype), None
448-
)
449-
)
450-
query_string = self._construct_query_string(
451-
FeatureGroupToBeMerged(
452-
base_features,
453-
self._included_feature_names if self._included_feature_names else base_features,
454-
self._included_feature_names if self._included_feature_names else base_features,
455-
_DEFAULT_CATALOG,
456-
_DEFAULT_DATABASE,
457-
temp_table_name,
458-
self._record_identifier_feature_name,
459-
FeatureDefinition(
460-
self._event_time_identifier_feature_name,
461-
self._event_time_identifier_feature_type,
462-
),
463-
None,
464-
TableType.DATA_FRAME,
465-
)
466-
)
467-
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
441+
query_result = self._run_query(*self._to_athena_query(temp_table_name=temp_table_name))
468442
# TODO: cleanup temp table, need more clarification, keep it for now
469443
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
470444
"OutputLocation", None
471445
), query_result.get("QueryExecution", {}).get("Query", None)
472446
if isinstance(self._base, FeatureGroup):
473-
base_feature_group = construct_feature_group_to_be_merged(
474-
self._base, self._included_feature_names
475-
)
476-
self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
477-
self._event_time_identifier_feature_name = (
478-
base_feature_group.event_time_identifier_feature.feature_name
479-
)
480-
self._event_time_identifier_feature_type = (
481-
base_feature_group.event_time_identifier_feature.feature_type
482-
)
483-
query_string = self._construct_query_string(base_feature_group)
484-
query_result = self._run_query(
485-
query_string,
486-
base_feature_group.catalog,
487-
base_feature_group.database,
488-
)
447+
query_result = self._run_query(*self._to_athena_query())
489448
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
490449
"OutputLocation", None
491450
), query_result.get("QueryExecution", {}).get("Query", None)
@@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
10581017
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
10591018
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
10601019

1020+
def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
    """Internal method for constructing an Athena query.

    Builds the query string for the base of this builder and reports the
    catalog and database the query has to be executed against.

    As a side effect the cached identifier attributes on the instance are
    refreshed: for a pandas.DataFrame base only
    ``_event_time_identifier_feature_type`` is recomputed from the dtype of
    the event time column; for a FeatureGroup base the record identifier
    name and the event time identifier name and type are overwritten with
    the values of the feature group to be merged.

    Args:
        temp_table_name (str): The temporary Athena table name of the base
            pandas.DataFrame. Required when the base is a pandas.DataFrame,
            not used when the base is a FeatureGroup. Defaults to None.

    Returns:
        A tuple of three strings:
            The query string.
            The name of the catalog to be used in the query execution.
            The database to be used in the query execution.

    Raises:
        ValueError: temp_table_name is not provided although the base is a
            pandas.DataFrame, or the base is neither a pandas.DataFrame nor
            a FeatureGroup.
    """
    if isinstance(self._base, pd.DataFrame):
        if temp_table_name is None:
            raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
        base_features = list(self._base.columns)
        # The feature type of the event time column is derived from its pandas
        # dtype; a dtype missing from the map yields None.
        event_time_identifier_feature_dtype = self._base[
            self._event_time_identifier_feature_name
        ].dtypes
        self._event_time_identifier_feature_type = (
            FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
                str(event_time_identifier_feature_dtype), None
            )
        )
        catalog = _DEFAULT_CATALOG
        database = _DEFAULT_DATABASE
        # Without an explicit selection every column of the data frame is
        # included; the same list is passed for both feature arguments.
        selected_features = self._included_feature_names or base_features
        query_string = self._construct_query_string(
            FeatureGroupToBeMerged(
                base_features,
                selected_features,
                selected_features,
                catalog,
                database,
                temp_table_name,
                self._record_identifier_feature_name,
                FeatureDefinition(
                    self._event_time_identifier_feature_name,
                    self._event_time_identifier_feature_type,
                ),
                None,
                TableType.DATA_FRAME,
            )
        )
    elif isinstance(self._base, FeatureGroup):
        base_feature_group = construct_feature_group_to_be_merged(
            self._base, self._included_feature_names
        )
        self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
        self._event_time_identifier_feature_name = (
            base_feature_group.event_time_identifier_feature.feature_name
        )
        self._event_time_identifier_feature_type = (
            base_feature_group.event_time_identifier_feature.feature_type
        )
        catalog = base_feature_group.catalog
        database = base_feature_group.database
        query_string = self._construct_query_string(base_feature_group)
    else:
        # Fail with a clear message instead of an UnboundLocalError on the
        # return statement below when the base has an unsupported type.
        raise ValueError("Base must be either a FeatureGroup or a pandas.DataFrame.")
    return query_string, catalog, database
1080+
10611081
def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
10621082
"""Internal method for execute Athena query, wait for query finish and get query result.
10631083

0 commit comments

Comments
 (0)