Skip to content

Commit

Permalink
Refactor!: treat Nullable as an arg instead of a DataType.TYPE (#4094)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Sep 9, 2024
1 parent 5733600 commit 9c527b5
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 40 deletions.
19 changes: 9 additions & 10 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,12 @@ def _parse_types(
dtype = super()._parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
if isinstance(dtype, exp.DataType):
# Mark every type as non-nullable which is ClickHouse's default. This marker
# helps us transpile types from other dialects to ClickHouse, so that we can
# e.g. produce `CAST(x AS Nullable(String))` from `CAST(x AS TEXT)`. If there
# is a `NULL` value in `x`, the former would fail in ClickHouse without the
# `Nullable` type constructor
if isinstance(dtype, exp.DataType) and dtype.args.get("nullable") is not True:
# Mark every type as non-nullable which is ClickHouse's default, unless it's
# already marked as nullable. This marker helps us transpile types from other
# dialects to ClickHouse, so that we can e.g. produce `CAST(x AS Nullable(String))`
# from `CAST(x AS TEXT)`. If there is a `NULL` value in `x`, the former would
# fail in ClickHouse without the `Nullable` type constructor.
dtype.set("nullable", False)

return dtype
Expand Down Expand Up @@ -815,7 +815,6 @@ class Generator(generator.Generator):
exp.DataType.Type.LOWCARDINALITY: "LowCardinality",
exp.DataType.Type.MAP: "Map",
exp.DataType.Type.NESTED: "Nested",
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
Expand Down Expand Up @@ -921,7 +920,6 @@ class Generator(generator.Generator):
NON_NULLABLE_TYPES = {
exp.DataType.Type.ARRAY,
exp.DataType.Type.MAP,
exp.DataType.Type.NULLABLE,
exp.DataType.Type.STRUCT,
}

Expand Down Expand Up @@ -1004,8 +1002,9 @@ def datatype_sql(self, expression: exp.DataType) -> str:
# String or FixedString (possibly LowCardinality) or UUID or IPv6"
# - It's not a composite type, e.g. `Nullable(Array(...))` is not a valid type
parent = expression.parent
if (
expression.args.get("nullable") is not False
nullable = expression.args.get("nullable")
if nullable is True or (
nullable is None
and not (
isinstance(parent, exp.DataType)
and parent.is_type(exp.DataType.Type.MAP, check_nullable=True)
Expand Down
3 changes: 0 additions & 3 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,6 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()

if enum not in ("", "clickhouse"):
klass.generator_class.SUPPORTS_NULLABLE_TYPES = False

if enum not in ("", "athena", "presto", "trino"):
klass.generator_class.TRY_SUPPORTED = False
klass.generator_class.SUPPORTS_UESCAPE = False
Expand Down
26 changes: 6 additions & 20 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4118,7 +4118,6 @@ class Type(AutoName):
NCHAR = auto()
NESTED = auto()
NULL = auto()
NULLABLE = auto()
NUMMULTIRANGE = auto()
NUMRANGE = auto()
NVARCHAR = auto()
Expand Down Expand Up @@ -4312,32 +4311,19 @@ def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool:
Returns:
True, if and only if there is a type in `dtypes` which is equal to this DataType.
"""
if (
not check_nullable
and self.this == DataType.Type.NULLABLE
and len(self.expressions) == 1
):
this_type = self.expressions[0]
else:
this_type = self

self_is_nullable = self.args.get("nullable")
for dtype in dtypes:
other_type = DataType.build(dtype, copy=False, udt=True)
if (
not check_nullable
and other_type.this == DataType.Type.NULLABLE
and len(other_type.expressions) == 1
):
other_type = other_type.expressions[0]

other_is_nullable = other_type.args.get("nullable")
if (
other_type.expressions
or this_type.this == DataType.Type.USERDEFINED
or (check_nullable and (self_is_nullable or other_is_nullable))
or self.this == DataType.Type.USERDEFINED
or other_type.this == DataType.Type.USERDEFINED
):
matches = this_type == other_type
matches = self == other_type
else:
matches = this_type.this == other_type.this
matches = self.this == other_type.this

if matches:
return True
Expand Down
7 changes: 1 addition & 6 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ class Generator(metaclass=_Generator):
# Whether CONVERT_TIMEZONE() is supported; if not, it will be generated as exp.AtTimeZone
SUPPORTS_CONVERT_TIMEZONE = False

# Whether nullable types can be constructed, e.g. `Nullable(Int64)`
SUPPORTS_NULLABLE_TYPES = True

# The name to generate for the JSONPath expression. If `None`, only `this` will be generated
PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON"

Expand Down Expand Up @@ -1239,14 +1236,12 @@ def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"):
type_sql = self.sql(expression, "kind")
elif type_value != exp.DataType.Type.NULLABLE or self.SUPPORTS_NULLABLE_TYPES:
else:
type_sql = (
self.TYPE_MAPPING.get(type_value, type_value.value)
if isinstance(type_value, exp.DataType.Type)
else type_value
)
else:
return interior

if interior:
if expression.args.get("nested"):
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4753,6 +4753,11 @@ def _parse_types(
check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
)
)
if type_token == TokenType.NULLABLE and len(expressions) == 1:
this = expressions[0]
this.set("nullable", True)
self._match_r_paren()
return this
elif type_token in self.ENUM_TYPE_TOKENS:
expressions = self._parse_csv(self._parse_equality)
elif is_aggregate:
Expand Down
1 change: 0 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,6 @@ def test_data_type_builder(self):
self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY")
self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY")
self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT")
self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE")
self.assertEqual(exp.DataType.build("HLLSKETCH", dialect="redshift").sql(), "HLLSKETCH")
self.assertEqual(exp.DataType.build("HSTORE", dialect="postgres").sql(), "HSTORE")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
Expand Down

0 comments on commit 9c527b5

Please sign in to comment.