diff --git a/src/databricks/sqlalchemy/dialect/__init__.py b/src/databricks/sqlalchemy/dialect/__init__.py index da508bb09..17d539d8a 100644 --- a/src/databricks/sqlalchemy/dialect/__init__.py +++ b/src/databricks/sqlalchemy/dialect/__init__.py @@ -65,6 +65,18 @@ def adapt(self, impltype, **kwargs): return self.impl +class TINYINT(types.Integer): + """ + A one-byte signed integer. Can represent any integer number in the range from -128 to 127. + + Implementation copied from mssql dialect + + Details about this type: https://docs.databricks.com/sql/language-manual/data-types/tinyint-type.html + """ + + __visit_name__ = "TINYINT" + + class DatabricksDialect(default.DefaultDialect): """This dialect implements only those methods required to pass our e2e tests""" @@ -136,6 +148,7 @@ def get_columns(self, connection, table_name, schema=None, **kwargs): _type_map = { "boolean": types.Boolean, "smallint": types.SmallInteger, + "tinyint": TINYINT, "int": types.Integer, "bigint": types.BigInteger, "float": types.Float, diff --git a/src/databricks/sqlalchemy/dialect/compiler.py b/src/databricks/sqlalchemy/dialect/compiler.py index f77807ed4..01af0ee19 100644 --- a/src/databricks/sqlalchemy/dialect/compiler.py +++ b/src/databricks/sqlalchemy/dialect/compiler.py @@ -36,3 +36,6 @@ def visit_DATE(self, type_): def visit_DATETIME(self, type_): return "TIMESTAMP" + + def visit_TINYINT(self, type_): + return "TINYINT" diff --git a/tests/e2e/sqlalchemy/test_basic.py b/tests/e2e/sqlalchemy/test_basic.py index 4f4df91b6..95fc6b9ad 100644 --- a/tests/e2e/sqlalchemy/test_basic.py +++ b/tests/e2e/sqlalchemy/test_basic.py @@ -4,6 +4,7 @@ from sqlalchemy import create_engine, select, insert, Column, MetaData, Table from sqlalchemy.orm import declarative_base, Session from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String, DECIMAL, Date +from databricks.sqlalchemy.dialect import TINYINT USER_AGENT_TOKEN = "PySQL e2e Tests" @@ -150,12 +151,13 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData): Column("episodes", Integer), Column("some_bool", BOOLEAN), Column("dollars", DECIMAL(10, 2)), + Column("tiny_int", TINYINT) ) metadata_obj.create_all() insert_stmt = insert(SampleTable).values( - name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125) + name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125), tiny_int=25 ) with db_engine.connect() as conn: