From d947102b9e6f5a8dc29abcee19403ed7ec068951 Mon Sep 17 00:00:00 2001 From: Walt Askew Date: Tue, 22 Apr 2025 00:34:45 +0000 Subject: [PATCH 1/3] feat: support named schemas Add support for SELECT, UPDATE and DELETE statements against tables in schemas. Schema names are not allowed in Spanner SELECT statements. We need to avoid generating SQL like ```sql SELECT schema.tbl.id FROM schema.tbl ``` To do so, we alias the table in order to produce SQL like: ```sql SELECT tbl_1.id, tbl_1.col FROM schema.tbl AS tbl_1 ``` --- README.rst | 1 - .../sqlalchemy_spanner/sqlalchemy_spanner.py | 85 ++++++++++++++++++- test/system/test_basics.py | 50 +++++++++++ 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 927848dc..2a195455 100644 --- a/README.rst +++ b/README.rst @@ -465,7 +465,6 @@ Other limitations ~~~~~~~~~~~~~~~~~ - WITH RECURSIVE statement is not supported. -- Named schemas are not supported. - Temporary tables are not supported. - Numeric type dimensions (scale and precision) are constant. See the `docs `__. diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 9670327f..fa3242cf 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler): compound_keywords = _compound_keywords + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super().__init__(*args, **kwargs) + def get_from_hint_text(self, _, text): """Return a hint text. @@ -378,8 +382,10 @@ def limit_clause(self, select, **kw): return text def returning_clause(self, stmt, returning_cols, **kw): + # Set include_table=False because although table names are allowed in + # RETURNING clauses, schema names are not. 
columns = [ - self._label_select_column(None, c, True, False, {}) + self._label_select_column(None, c, True, False, {}, include_table=False) for c in expression._select_iterables(returning_cols) ] @@ -391,6 +397,83 @@ def visit_sequence(self, seq, **kw): seq ) + def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs): + """Produces the table name. + + Schema names are not allowed in Spanner SELECT statements. We + need to avoid generating SQL like + + SELECT schema.tbl.id + FROM schema.tbl + + To do so, we alias the table in order to produce SQL like: + + SELECT tbl_1.id, tbl_1.col + FROM schema.tbl AS tbl_1 + + And do similar for UPDATE and DELETE statements. + + We don't need to correct INSERT statements, which is fortunate + because INSERT statements actually do not currently result in + calls to `visit_table`. + + This closely mirrors the mssql dialect which also avoids + schema-qualified columns in SELECTs, although the behaviour is + currently behind a deprecated 'legacy_schema_aliasing' flag. + """ + if spanner_aliased is table or self.isinsert: + return super().visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, spanner_aliased=table, **kwargs) + else: + return super().visit_table(table, **kwargs) + + def visit_alias(self, alias, **kw): + """Produces alias statements.""" + # translate for schema-qualified table aliases + kw["spanner_aliased"] = alias.element + return super().visit_alias(alias, **kw) + + def visit_column(self, column, add_to_result_map=None, **kw): + """Produces column expressions. + + In tandem with visit_table, replaces schema-qualified column + names with column names qualified against an alias. 
+ """ + if column.table is not None and not self.isinsert or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = elements._corresponding_column_or_error(t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type, + ) + + return super().visit_column(converted, **kw) + + return super().visit_column(column, add_to_result_map=add_to_result_map, **kw) + + def _schema_aliased_table(self, table): + """Creates an alias for the table if it is schema-qualified. + + If the table is schema-qualified, returns an alias for the + table and caches the alias for future references to the + table. If the table is not schema-qualified, returns None. + """ + if getattr(table, "schema", None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + class SpannerDDLCompiler(DDLCompiler): """Spanner DDL statements compiler.""" diff --git a/test/system/test_basics.py b/test/system/test_basics.py index e5411988..3001052d 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -25,6 +25,8 @@ Boolean, BIGINT, select, + update, + delete, ) from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column from sqlalchemy.types import REAL @@ -58,6 +60,16 @@ def define_tables(cls, metadata): Column("name", String(20)), ) + with cls.bind.begin() as conn: + conn.execute(text("CREATE SCHEMA IF NOT EXISTS schema")) + Table( + "users", + metadata, + Column("ID", Integer, primary_key=True), + Column("name", String(20)), + schema="schema", + ) + def test_hello_world(self, connection): greeting = connection.execute(text("select 'Hello World'")) eq_("Hello World", greeting.fetchone()[0]) @@ -139,6 +151,12 @@ class User(Base): ID: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = 
mapped_column(String(20)) + class SchemaUser(Base): + __tablename__ = "users" + __table_args__ = {"schema": "schema"} + ID: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(20)) + engine = connection.engine with Session(engine) as session: number = Number( @@ -156,3 +174,35 @@ class User(Base): users = session.scalars(statement).all() eq_(1, len(users)) is_true(users[0].ID > 0) + + with Session(engine) as session: + user = SchemaUser(name="SchemaTest") + session.add(user) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "SchemaTest") + ).all() + eq_(1, len(users)) + is_true(users[0].ID > 0) + + session.execute( + update(SchemaUser) + .where(SchemaUser.name == "SchemaTest") + .values(name="NewName") + ) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "NewName") + ).all() + eq_(1, len(users)) + is_true(users[0].ID > 0) + + session.execute(delete(SchemaUser).where(SchemaUser.name == "NewName")) + session.commit() + + users = session.scalars( + select(SchemaUser).where(SchemaUser.name == "NewName") + ).all() + eq_(0, len(users)) From 2638b77513b7d7606c6caac0c5641579c185f538 Mon Sep 17 00:00:00 2001 From: Walt Askew Date: Thu, 1 May 2025 21:53:04 +0000 Subject: [PATCH 2/3] feat: backwards-compatible include_table=False when returning In sqlalchemy 2.0, it's: ```python self._label_select_column(None, c, True, False, {}, include_table=False) ``` In older versions, it's: ```python self._label_select_column(None, c, True, False, {'include_table': False}) ``` So instead, forward a flag to `visit_column` which can set this kwarg safely itself. 
--- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 30 ++++++++++++++----- test/mockserver_tests/test_auto_increment.py | 4 +-- .../test_bit_reversed_sequence.py | 2 +- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index fa3242cf..30f43311 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -382,10 +382,11 @@ def limit_clause(self, select, **kw): return text def returning_clause(self, stmt, returning_cols, **kw): - # Set include_table=False because although table names are allowed in - # RETURNING clauses, schema names are not. + # Set the spanner_is_returning flag which is passed to visit_column. columns = [ - self._label_select_column(None, c, True, False, {}, include_table=False) + self._label_select_column( + None, c, True, False, {"spanner_is_returning": True} + ) for c in expression._select_iterables(returning_cols) ] @@ -413,10 +414,6 @@ def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs): And do similar for UPDATE and DELETE statements. - We don't need to correct INSERT statements, which is fortunate - because INSERT statements actually do not currently result in - calls to `visit_table`. - This closely mirrors the mssql dialect which also avoids schema-qualified columns in SELECTs, although the behaviour is currently behind a deprecated 'legacy_schema_aliasing' flag. @@ -437,7 +434,9 @@ def visit_alias(self, alias, **kw): kw["spanner_aliased"] = alias.element return super().visit_alias(alias, **kw) - def visit_column(self, column, add_to_result_map=None, **kw): + def visit_column( + self, column, add_to_result_map=None, spanner_is_returning=False, **kw + ): """Produces column expressions. 
In tandem with visit_table, replaces schema-qualified column @@ -457,6 +456,21 @@ def visit_column(self, column, add_to_result_map=None, **kw): ) return super().visit_column(converted, **kw) + if spanner_is_returning: + # Set include_table=False because although table names are + # allowed in RETURNING clauses, schema names are not. We + # can't use the same aliasing trick above that we use with + # other statements, because INSERT statements don't result + # in visit_table calls and INSERT table names can't be + # aliased. Statements like: + # + # INSERT INTO table (id, name) + # SELECT id, name FROM another_table + # THEN RETURN another_table.id + # + # aren't legal, so the columns remain unambiguous when not + # qualified by table name. + kw["include_table"] = False return super().visit_column(column, add_to_result_map=add_to_result_map, **kw) diff --git a/test/mockserver_tests/test_auto_increment.py b/test/mockserver_tests/test_auto_increment.py index 6bc5e2c0..7fa245e8 100644 --- a/test/mockserver_tests/test_auto_increment.py +++ b/test/mockserver_tests/test_auto_increment.py @@ -125,9 +125,7 @@ def test_create_table_with_specific_sequence_kind(self): def test_insert_row(self): from test.mockserver_tests.auto_increment_model import Singer - self.add_insert_result( - "INSERT INTO singers (name) VALUES (@a0) THEN RETURN singers.id" - ) + self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id") engine = create_engine( "spanner:///projects/p/instances/i/databases/d", connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, diff --git a/test/mockserver_tests/test_bit_reversed_sequence.py b/test/mockserver_tests/test_bit_reversed_sequence.py index a18bc08e..9e7a81a8 100644 --- a/test/mockserver_tests/test_bit_reversed_sequence.py +++ b/test/mockserver_tests/test_bit_reversed_sequence.py @@ -110,7 +110,7 @@ def test_insert_row(self): add_result( "INSERT INTO singers (id, name) " "VALUES ( GET_NEXT_SEQUENCE_VALUE(SEQUENCE 
singer_id), @a0) " - "THEN RETURN singers.id", + "THEN RETURN id", result, ) engine = create_engine( From e873ba18e84f4551cb92ed0988084342ee9a67da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 2 May 2025 15:10:18 +0200 Subject: [PATCH 3/3] test: add mock server tests to guarantee exact SQL generation --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 4 +- test/mockserver_tests/test_basics.py | 50 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index e83589f7..e5559c65 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -421,7 +421,9 @@ def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs): if spanner_aliased is table or self.isinsert: return super().visit_table(table, **kwargs) - # alias schema-qualified tables + # Add an alias for schema-qualified tables. + # Tables in the default schema are not aliased and follow the + # standard SQLAlchemy code path. 
alias = self._schema_aliased_table(table) if alias is not None: return self.process(alias, spanner_aliased=table, **kwargs) diff --git a/test/mockserver_tests/test_basics.py b/test/mockserver_tests/test_basics.py index 6db248d6..e1445829 100644 --- a/test/mockserver_tests/test_basics.py +++ b/test/mockserver_tests/test_basics.py @@ -262,3 +262,53 @@ class Singer(Base): singer.name = "New Name" session.add(singer) session.commit() + + def test_select_table_in_named_schema(self): + class Base(DeclarativeBase): + pass + + class Singer(Base): + __tablename__ = "singers" + __table_args__ = {"schema": "my_schema"} + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + name: Mapped[str] = mapped_column(String) + + query = ( + "SELECT" + " singers_1.id AS my_schema_singers_id," + " singers_1.name AS my_schema_singers_name\n" + "FROM my_schema.singers AS singers_1\n" + "WHERE singers_1.id = @a0\n" + " LIMIT @a1" + ) + add_singer_query_result(query) + engine = create_engine( + "spanner:///projects/p/instances/i/databases/d", + connect_args={"client": self.client, "pool": FixedSizePool(size=10)}, + ) + + insert = "INSERT INTO my_schema.singers (name) VALUES (@a0) THEN RETURN id" + add_single_result(insert, "id", TypeCode.INT64, [("1",)]) + with Session(engine) as session: + singer = Singer(name="New Name") + session.add(singer) + session.commit() + + update = ( + "UPDATE my_schema.singers AS singers_1 " + "SET name=@a0 " + "WHERE singers_1.id = @a1" + ) + add_update_count(update, 1) + with Session(engine) as session: + singer = session.query(Singer).filter(Singer.id == 1).first() + singer.name = "New Name" + session.add(singer) + session.commit() + + delete = "DELETE FROM my_schema.singers AS singers_1 WHERE singers_1.id = @a0" + add_update_count(delete, 1) + with Session(engine) as session: + singer = session.query(Singer).filter(Singer.id == 1).first() + session.delete(singer) + session.commit()