Skip to content

Commit

Permalink
Fixes mode and metal transactions (#2160)
Browse files Browse the repository at this point in the history
* Fixes mode and metal transactions

* Allow forward only changes

* handle repeated columns
  • Loading branch information
ravenac95 authored Sep 17, 2024
1 parent 7cf3413 commit 4540f34
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 11 deletions.
2 changes: 1 addition & 1 deletion warehouse/oso_dagster/assets/metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
network_name="metal",
destination_dataset_name="superchain",
working_destination_dataset_name="oso_raw_sources",
transactions_config={"source_name": "metal-receipt_transactions"},
transactions_config={"source_name": "metal-enriched_transactions"},
)
2 changes: 1 addition & 1 deletion warehouse/oso_dagster/assets/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
network_name="mode",
destination_dataset_name="superchain",
working_destination_dataset_name="oso_raw_sources",
transactions_config={"source_name": "mode-receipt_transactions"},
transactions_config={"source_name": "mode-enriched_transactions"},
)
55 changes: 54 additions & 1 deletion warehouse/oso_dagster/factories/goldsky/assets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import heapq
import io
import logging
import os
import random
import re
Expand All @@ -23,6 +24,7 @@
from google.cloud.bigquery import Client as BQClient
from google.cloud.bigquery import LoadJobConfig, SourceFormat, TableReference
from google.cloud.bigquery.schema import SchemaField
from oso_dagster.utils.bq import compare_schemas, get_table_schema
from polars.type_aliases import PolarsDataType

from ...cbt import CBTResource, TimePartitioning, UpdateStrategy
Expand Down Expand Up @@ -523,6 +525,11 @@ def load_schema(self, queues: GoldskyQueues):
finally:
client.close()

def load_schema_for_bq_table(self, table_ref: str):
with self.bigquery.get_client() as client:
return get_table_schema(client, table_ref)


def ensure_datasets(self, context: GenericExecutionContext):
self.ensure_dataset(context, self.config.destination_dataset_name)
self.ensure_dataset(context, self.config.working_destination_dataset_name)
Expand Down Expand Up @@ -647,7 +654,10 @@ async def merge_worker_tables(

worker_deduped_table = self.config.worker_deduped_table_fqdn(workers[0].name)

context.log.warn(f"Worker table to use for schema {worker_deduped_table}")
context.log.warning(f"Worker table to use for schema {worker_deduped_table}")

# check the schema of the destination and the worker table. if it's only new rows then add those rose
self.ensure_schema_or_fail(context.log, worker_deduped_table, self.config.destination_table_fqn)

cbt.transform(
self.config.merge_workers_model,
Expand All @@ -663,6 +673,49 @@ async def merge_worker_tables(
source_table_fqn=worker_deduped_table,
)

def ensure_schema_or_fail(self, log: logging.Logger, source_table: str, destination_table: str):
source_schema = self.load_schema_for_bq_table(source_table)
destination_schema = self.load_schema_for_bq_table(destination_table)

source_only, destination_only, modified = compare_schemas(source_schema, destination_schema)
if len(modified) > 0:
log.error(dict(
msg=f"cannot merge automatically into {destination_table} schema has been altered:",
destination_only=destination_only,
source_only=source_only,
modified=modified,
))
raise Exception(f"cannot merge, schemas incompatible in {source_schema} and {destination_table}")
# IF things are only in the destination that just means the new data removed a column. We can log and ignore.
if len(destination_only) > 0:
log.warning(dict(
msg="new data no longer has some columns",
columns=destination_only,
))
# If only the source or only the destination are different then we can update
if len(source_only) > 0:
log.info(dict(
msg="updating table to include new columns from the source data",
columns=source_only,
))
with self.bigquery.get_client() as client:
table = client.get_table(destination_table)
updated_schema = table.schema[:]
# Force all of the fields to be nullable
new_fields: List[SchemaField] = []
for field in source_only.values():
if field.mode not in ["NULLABLE", "REPEATED"]:
field_dict = field.to_api_repr()
field_dict["mode"] = "NULLABLE"
new_fields.append(SchemaField.from_api_repr(field_dict))
else:
new_fields.append(field)
updated_schema.extend(new_fields)
table.schema = updated_schema

client.update_table(table, ["schema"])


async def clean_working_destination(
self, context: GenericExecutionContext, workers: List[GoldskyWorker]
):
Expand Down
114 changes: 106 additions & 8 deletions warehouse/oso_dagster/utils/bq.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import List, Optional
import typing as t
from dataclasses import dataclass

from google.cloud.bigquery import DatasetReference, AccessEntry, Client as BQClient, ExtractJobConfig
from google.cloud import bigquery
from google.cloud.bigquery import AccessEntry
from google.cloud.bigquery import Client as BQClient
from google.cloud.bigquery import DatasetReference, ExtractJobConfig
from google.cloud.bigquery.enums import EntityTypes
from google.cloud.bigquery.schema import SchemaField
from google.cloud.exceptions import NotFound, PreconditionFailed

from .retry import retry


# Configuration for a BigQuery Dataset
@dataclass(kw_only=True)
class BigQueryDatasetConfig:
Expand All @@ -14,18 +20,21 @@ class BigQueryDatasetConfig:
# BigQuery dataset
dataset_name: str
# Service account
service_account: Optional[str]
service_account: t.Optional[str]


@dataclass(kw_only=True)
class BigQueryTableConfig(BigQueryDatasetConfig):
# BigQuery table
table_name: str


@dataclass
class DatasetOptions:
dataset_ref: DatasetReference
is_public: bool = False


def ensure_dataset(client: BQClient, options: DatasetOptions):
"""
Create a public dataset if missing
Expand All @@ -37,7 +46,7 @@ def ensure_dataset(client: BQClient, options: DatasetOptions):
The Google BigQuery client
options: DatasetOptions
"""

try:
client.get_dataset(options.dataset_ref)
except NotFound:
Expand All @@ -52,7 +61,7 @@ def retry_update():
if entry.entity_id == "allAuthenticatedUsers" and entry.role == "READER":
return

new_entries: List[AccessEntry] = []
new_entries: t.List[AccessEntry] = []
if options.is_public:
new_entries.append(
AccessEntry(
Expand All @@ -71,7 +80,10 @@ def error_handler(exc: Exception):

retry(retry_update, error_handler)

def export_to_gcs(bq_client: BQClient, bq_table_config: BigQueryTableConfig, gcs_path: str):

def export_to_gcs(
bq_client: BQClient, bq_table_config: BigQueryTableConfig, gcs_path: str
):
"""
Export a BigQuery table to partitioned CSV files in GCS
Expand Down Expand Up @@ -103,8 +115,94 @@ def export_to_gcs(bq_client: BQClient, bq_table_config: BigQueryTableConfig, gcs
job_config=ExtractJobConfig(
print_header=False,
destination_format="PARQUET",
#compression="ZSTD",
# compression="ZSTD",
),
)
extract_job.result()
return destination_uri
return destination_uri


def get_table_schema(
client: bigquery.Client, table_ref: bigquery.TableReference | str
) -> t.List[SchemaField]:
"""Fetches the schema of a table."""
table = client.get_table(table_ref)
return table.schema


def compare_schemas(
schema1: t.List[SchemaField], schema2: t.List[SchemaField]
) -> t.Tuple[
t.Dict[str, SchemaField],
t.Dict[str, SchemaField],
t.Dict[str, t.Dict[str, SchemaField]],
]:
"""Compares two BigQuery schemas and outputs the differences.
Returns a tuple containing:
- Fields only in schema1
- Fields only in schema2
- Fields present in both schemas but with different properties
"""
schema1_fields: t.Dict[str, SchemaField] = {field.name: field for field in schema1}
schema2_fields: t.Dict[str, SchemaField] = {field.name: field for field in schema2}

# Fields only in schema1
schema1_only: t.Dict[str, SchemaField] = {
name: schema1_fields[name]
for name in schema1_fields
if name not in schema2_fields
}

# Fields only in schema2
schema2_only: t.Dict[str, SchemaField] = {
name: schema2_fields[name]
for name in schema2_fields
if name not in schema1_fields
}

# Fields in both schemas but with different properties
modified_fields: t.Dict[str, t.Dict[str, SchemaField]] = {}
for name in schema1_fields:
if name in schema2_fields:
if schema1_fields[name] != schema2_fields[name]:
modified_fields[name] = {
"schema1": schema1_fields[name],
"schema2": schema2_fields[name],
}

return schema1_only, schema2_only, modified_fields


def print_schema_diff(
schema1_only: t.Dict[str, SchemaField],
schema2_only: t.Dict[str, SchemaField],
modified_fields: t.Dict[str, t.Dict[str, SchemaField]],
) -> None:
"""Prints the schema differences."""
if schema1_only:
print("Fields only in Schema 1:")
for field_name, field in schema1_only.items():
print(f" - {field_name}: {field.field_type}, {field.mode}")
else:
print("No fields unique to Schema 1.")

if schema2_only:
print("Fields only in Schema 2:")
for field_name, field in schema2_only.items():
print(f" - {field_name}: {field.field_type}, {field.mode}")
else:
print("No fields unique to Schema 2.")

if modified_fields:
print("Fields with differences:")
for field_name, fields in modified_fields.items():
print(f" - {field_name}:")
print(
f" Schema 1: {fields['schema1'].field_type}, {fields['schema1'].mode}"
)
print(
f" Schema 2: {fields['schema2'].field_type}, {fields['schema2'].mode}"
)
else:
print("No modified fields.")

0 comments on commit 4540f34

Please sign in to comment.