From 35c431318308aa00c65ff12e3b0388428edad0a3 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Mon, 2 Oct 2023 13:07:28 -0400 Subject: [PATCH 1/2] (1/x) Add tests for UPPERCASE types. I refactored to reference sqlalchemy.types directly rather than having such a long import list. Also refactored _as_tuple_list into a function so I can reuse it for the uppercase types Signed-off-by: Jesse Whitehouse --- .../sqlalchemy/test_local/test_types.py | 118 ++++++++++-------- src/databricks/sqlalchemy/types.py | 37 ++---- 2 files changed, 80 insertions(+), 75 deletions(-) diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py index 50d1fd85f..06415c2ba 100644 --- a/src/databricks/sqlalchemy/test_local/test_types.py +++ b/src/databricks/sqlalchemy/test_local/test_types.py @@ -1,30 +1,7 @@ import enum import pytest -from sqlalchemy.types import ( - BigInteger, - Boolean, - Date, - DateTime, - Double, - Enum, - Float, - Integer, - Interval, - LargeBinary, - MatchType, - Numeric, - PickleType, - SchemaType, - SmallInteger, - String, - Text, - Time, - TypeEngine, - Unicode, - UnicodeText, - Uuid, -) +import sqlalchemy from databricks.sqlalchemy import DatabricksDialect @@ -55,35 +32,39 @@ class DatabricksDataType(enum.Enum): # Defines the way that SQLAlchemy CamelCase types are compiled into Databricks SQL types. # Note: I wish I could define this within the TestCamelCaseTypesCompilation class, but pytest doesn't like that. 
camel_case_type_map = { - BigInteger: DatabricksDataType.BIGINT, - LargeBinary: DatabricksDataType.BINARY, - Boolean: DatabricksDataType.BOOLEAN, - Date: DatabricksDataType.DATE, - DateTime: DatabricksDataType.TIMESTAMP, - Double: DatabricksDataType.DOUBLE, - Enum: DatabricksDataType.STRING, - Float: DatabricksDataType.FLOAT, - Integer: DatabricksDataType.INT, - Interval: DatabricksDataType.TIMESTAMP, - Numeric: DatabricksDataType.DECIMAL, - PickleType: DatabricksDataType.BINARY, - SmallInteger: DatabricksDataType.SMALLINT, - String: DatabricksDataType.STRING, - Text: DatabricksDataType.STRING, - Time: DatabricksDataType.STRING, - Unicode: DatabricksDataType.STRING, - UnicodeText: DatabricksDataType.STRING, - Uuid: DatabricksDataType.STRING, + sqlalchemy.types.BigInteger: DatabricksDataType.BIGINT, + sqlalchemy.types.LargeBinary: DatabricksDataType.BINARY, + sqlalchemy.types.Boolean: DatabricksDataType.BOOLEAN, + sqlalchemy.types.Date: DatabricksDataType.DATE, + sqlalchemy.types.DateTime: DatabricksDataType.TIMESTAMP, + sqlalchemy.types.Double: DatabricksDataType.DOUBLE, + sqlalchemy.types.Enum: DatabricksDataType.STRING, + sqlalchemy.types.Float: DatabricksDataType.FLOAT, + sqlalchemy.types.Integer: DatabricksDataType.INT, + sqlalchemy.types.Interval: DatabricksDataType.TIMESTAMP, + sqlalchemy.types.Numeric: DatabricksDataType.DECIMAL, + sqlalchemy.types.PickleType: DatabricksDataType.BINARY, + sqlalchemy.types.SmallInteger: DatabricksDataType.SMALLINT, + sqlalchemy.types.String: DatabricksDataType.STRING, + sqlalchemy.types.Text: DatabricksDataType.STRING, + sqlalchemy.types.Time: DatabricksDataType.STRING, + sqlalchemy.types.Unicode: DatabricksDataType.STRING, + sqlalchemy.types.UnicodeText: DatabricksDataType.STRING, + sqlalchemy.types.Uuid: DatabricksDataType.STRING, } -# Convert the dictionary into a list of tuples for use in pytest.mark.parametrize -_as_tuple_list = [(key, value) for key, value in camel_case_type_map.items()] + +def dict_as_tuple_list(d: 
dict): + """Return a list of [(key, value), ...] from a dictionary.""" + return [(key, value) for key, value in d.items()] class CompilationTestBase: dialect = DatabricksDialect() - def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType): + def _assert_compiled_value( + self, type_: sqlalchemy.types.TypeEngine, expected: DatabricksDataType + ): """Assert that when type_ is compiled for the databricks dialect, it renders the DatabricksDataType name. This method initialises the type_ with no arguments. @@ -91,7 +72,9 @@ def _assert_compiled_value(self, type_: TypeEngine, expected: DatabricksDataType compiled_result = type_().compile(dialect=self.dialect) # type: ignore assert compiled_result == expected.name - def _assert_compiled_value_explicit(self, type_: TypeEngine, expected: str): + def _assert_compiled_value_explicit( + self, type_: sqlalchemy.types.TypeEngine, expected: str + ): """Assert that when type_ is compiled for the databricks dialect, it renders the expected string. 
This method expects an initialised type_ so that we can test how a TypeEngine created with arguments @@ -117,12 +100,47 @@ class TestCamelCaseTypesCompilation(CompilationTestBase): [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types """ - @pytest.mark.parametrize("type_, expected", _as_tuple_list) + @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(camel_case_type_map)) def test_bare_camel_case_types_compile(self, type_, expected): self._assert_compiled_value(type_, expected) def test_numeric_renders_as_decimal_with_precision(self): - self._assert_compiled_value_explicit(Numeric(10), "DECIMAL(10)") + self._assert_compiled_value_explicit( + sqlalchemy.types.Numeric(10), "DECIMAL(10)" + ) def test_numeric_renders_as_decimal_with_precision_and_scale(self): - self._assert_compiled_value_explicit(Numeric(10, 2), "DECIMAL(10, 2)") + return self._assert_compiled_value_explicit( + sqlalchemy.types.Numeric(10, 2), "DECIMAL(10, 2)" + ) + + +uppercase_type_map = { + sqlalchemy.types.ARRAY: DatabricksDataType.ARRAY, + sqlalchemy.types.BIGINT: DatabricksDataType.BIGINT, + sqlalchemy.types.BINARY: DatabricksDataType.BINARY, + sqlalchemy.types.BOOLEAN: DatabricksDataType.BOOLEAN, + sqlalchemy.types.DATE: DatabricksDataType.DATE, + sqlalchemy.types.DECIMAL: DatabricksDataType.DECIMAL, + sqlalchemy.types.DOUBLE: DatabricksDataType.DOUBLE, + sqlalchemy.types.FLOAT: DatabricksDataType.FLOAT, + sqlalchemy.types.INT: DatabricksDataType.INT, + sqlalchemy.types.SMALLINT: DatabricksDataType.SMALLINT, + sqlalchemy.types.TIMESTAMP: DatabricksDataType.TIMESTAMP, +} + + +class TestUppercaseTypesCompilation(CompilationTestBase): + """Per the sqlalchemy documentation[^1], uppercase types are considered to be specific to some + database backends. These tests verify that the types compile into valid Databricks SQL type strings. 
+ + [1]: https://docs.sqlalchemy.org/en/20/core/type_basics.html#backend-specific-uppercase-datatypes + """ + + @pytest.mark.parametrize("type_, expected", dict_as_tuple_list(uppercase_type_map)) + def test_bare_uppercase_types_compile(self, type_, expected): + if isinstance(type_, type(sqlalchemy.types.ARRAY)): + # ARRAY cannot be initialised without passing an item definition so we test separately + # I preserve it in the uppercase_type_map for clarity + return True + return self._assert_compiled_value(type_, expected) diff --git a/src/databricks/sqlalchemy/types.py b/src/databricks/sqlalchemy/types.py index c6483d540..498e9adaf 100644 --- a/src/databricks/sqlalchemy/types.py +++ b/src/databricks/sqlalchemy/types.py @@ -1,27 +1,14 @@ +import sqlalchemy from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.compiler import GenericTypeCompiler -from sqlalchemy.types import ( - DateTime, - Enum, - Integer, - LargeBinary, - Numeric, - String, - Text, - Time, - Unicode, - UnicodeText, - Uuid, -) -@compiles(Enum, "databricks") -@compiles(String, "databricks") -@compiles(Text, "databricks") -@compiles(Time, "databricks") -@compiles(Unicode, "databricks") -@compiles(UnicodeText, "databricks") -@compiles(Uuid, "databricks") +@compiles(sqlalchemy.types.Enum, "databricks") +@compiles(sqlalchemy.types.String, "databricks") +@compiles(sqlalchemy.types.Text, "databricks") +@compiles(sqlalchemy.types.Time, "databricks") +@compiles(sqlalchemy.types.Unicode, "databricks") +@compiles(sqlalchemy.types.UnicodeText, "databricks") +@compiles(sqlalchemy.types.Uuid, "databricks") def compile_string_databricks(type_, compiler, **kw): """ We override the default compilation for Enum(), String(), Text(), and Time() because SQLAlchemy @@ -40,7 +27,7 @@ def compile_string_databricks(type_, compiler, **kw): return "STRING" -@compiles(Integer, "databricks") +@compiles(sqlalchemy.types.Integer, "databricks") def compile_integer_databricks(type_, compiler, **kw): """ We need to 
override the default Integer compilation rendering because Databricks uses "INT" instead of "INTEGER" @@ -48,7 +35,7 @@ def compile_integer_databricks(type_, compiler, **kw): return "INT" -@compiles(LargeBinary, "databricks") +@compiles(sqlalchemy.types.LargeBinary, "databricks") def compile_binary_databricks(type_, compiler, **kw): """ We need to override the default LargeBinary compilation rendering because Databricks uses "BINARY" instead of "BLOB" @@ -56,7 +43,7 @@ def compile_binary_databricks(type_, compiler, **kw): return "BINARY" -@compiles(Numeric, "databricks") +@compiles(sqlalchemy.types.Numeric, "databricks") def compile_numeric_databricks(type_, compiler, **kw): """ We need to override the default Numeric compilation rendering because Databricks uses "DECIMAL" instead of "NUMERIC" @@ -67,7 +54,7 @@ def compile_numeric_databricks(type_, compiler, **kw): return compiler.visit_DECIMAL(type_, **kw) -@compiles(DateTime, "databricks") +@compiles(sqlalchemy.types.DateTime, "databricks") def compile_datetime_databricks(type_, compiler, **kw): """ We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME" From 4771b1f6158c268e31935e4ec72abe613457fdfb Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Mon, 2 Oct 2023 13:07:47 -0400 Subject: [PATCH 2/2] (2/x) Add compilation support for ARRAY with a test Signed-off-by: Jesse Whitehouse --- .../sqlalchemy/test_local/test_types.py | 10 ++++++++++ src/databricks/sqlalchemy/types.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py index 06415c2ba..91f11e17e 100644 --- a/src/databricks/sqlalchemy/test_local/test_types.py +++ b/src/databricks/sqlalchemy/test_local/test_types.py @@ -144,3 +144,13 @@ def test_bare_uppercase_types_compile(self, type_, expected): # I preserve it in the uppercase_type_map for clarity return True return 
self._assert_compiled_value(type_, expected) + + def test_array_string_renders_as_array_of_string(self): + """SQLAlchemy's ARRAY type requires an item definition. And their docs indicate that they've only tested + it with Postgres since that's the only first-class dialect with support for ARRAY. + + https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY + """ + return self._assert_compiled_value_explicit( + sqlalchemy.types.ARRAY(sqlalchemy.types.String), "ARRAY" + ) diff --git a/src/databricks/sqlalchemy/types.py b/src/databricks/sqlalchemy/types.py index 498e9adaf..4b10fc6f1 100644 --- a/src/databricks/sqlalchemy/types.py +++ b/src/databricks/sqlalchemy/types.py @@ -60,3 +60,21 @@ def compile_datetime_databricks(type_, compiler, **kw): We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME" """ return "TIMESTAMP" + + +@compiles(sqlalchemy.types.ARRAY, "databricks") +def compile_array_databricks(type_, compiler, **kw): + """ + SQLAlchemy's default ARRAY can't compile as it's only implemented for Postgresql. + The Postgres implementation works for Databricks SQL, so we duplicate that here. + + :type_: + This is an instance of sqlalchemy.types.ARRAY which always includes an item_type attribute + which is itself an instance of TypeEngine + + https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.ARRAY + """ + + inner = compiler.process(type_.item_type, **kw) + + return f"ARRAY<{inner}>"