Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

feat: support schemas in queries and dml statements #639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion 1 README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ Other limitations
~~~~~~~~~~~~~~~~~

- WITH RECURSIVE statement is not supported.
- Named schemas are not supported.
- Temporary tables are not supported.
- Numeric type dimensions (scale and precision) are constant. See the
`docs <https://cloud.google.com/spanner/docs/data-types#numeric_types>`__.
Expand Down
101 changes: 100 additions & 1 deletion 101 google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler):

compound_keywords = _compound_keywords

def __init__(self, *args, **kwargs):
self.tablealiases = {}
super().__init__(*args, **kwargs)

def get_from_hint_text(self, _, text):
"""Return a hint text.

Expand Down Expand Up @@ -378,8 +382,11 @@ def limit_clause(self, select, **kw):
return text

def returning_clause(self, stmt, returning_cols, **kw):
# Set the spanner_is_returning flag which is passed to visit_column.
columns = [
self._label_select_column(None, c, True, False, {})
self._label_select_column(
None, c, True, False, {"spanner_is_returning": True}
)
for c in expression._select_iterables(returning_cols)
]

Expand All @@ -391,6 +398,98 @@ def visit_sequence(self, seq, **kw):
seq
)

def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs):
"""Produces the table name.

Schema names are not allowed in Spanner SELECT statements. We
need to avoid generating SQL like

SELECT schema.tbl.id
FROM schema.tbl

To do so, we alias the table in order to produce SQL like:

SELECT tbl_1.id, tbl_1.col
FROM schema.tbl AS tbl_1

And do similar for UPDATE and DELETE statements.

This closely mirrors the mssql dialect which also avoids
schema-qualified columns in SELECTs, although the behaviour is
currently behind a deprecated 'legacy_schema_aliasing' flag.
"""
if spanner_aliased is table or self.isinsert:
return super().visit_table(table, **kwargs)

# Add an alias for schema-qualified tables.
# Tables in the default schema are not aliased and follow the
# standard SQLAlchemy code path.
alias = self._schema_aliased_table(table)
if alias is not None:
return self.process(alias, spanner_aliased=table, **kwargs)
else:
return super().visit_table(table, **kwargs)

def visit_alias(self, alias, **kw):
"""Produces alias statements."""
# translate for schema-qualified table aliases
kw["spanner_aliased"] = alias.element
return super().visit_alias(alias, **kw)

def visit_column(
self, column, add_to_result_map=None, spanner_is_returning=False, **kw
):
"""Produces column expressions.

In tandem with visit_table, replaces schema-qualified column
names with column names qualified against an alias.
"""
if column.table is not None and not self.isinsert or self.is_subquery():
# translate for schema-qualified table aliases
t = self._schema_aliased_table(column.table)
if t is not None:
converted = elements._corresponding_column_or_error(t, column)
if add_to_result_map is not None:
add_to_result_map(
column.name,
column.name,
(column, column.name, column.key),
column.type,
)

return super().visit_column(converted, **kw)
if spanner_is_returning:
# Set include_table=False because although table names are
# allowed in RETURNING clauses, schema names are not. We
# can't use the same aliasing trick above that we use with
# other statements, because INSERT statements don't result
# in visit_table calls and INSERT table names can't be
# aliased. Statements like:
#
# INSERT INTO table (id, name)
# SELECT id, name FROM another_table
# THEN RETURN another_table.id
#
# aren't legal, so the columns remain unambiguous when not
# qualified by table name.
kw["include_table"] = False

return super().visit_column(column, add_to_result_map=add_to_result_map, **kw)

def _schema_aliased_table(self, table):
"""Creates an alias for the table if it is schema-qualified.

If the table is schema-qualified, returns an alias for the
table and caches the alias for future references to the
table. If the table is not schema-qualified, returns None.
"""
if getattr(table, "schema", None) is not None:
if table not in self.tablealiases:
self.tablealiases[table] = table.alias()
return self.tablealiases[table]
else:
return None


class SpannerDDLCompiler(DDLCompiler):
"""Spanner DDL statements compiler."""
Expand Down
4 changes: 1 addition & 3 deletions 4 test/mockserver_tests/test_auto_increment.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ def test_create_table_with_specific_sequence_kind(self):
def test_insert_row(self):
from test.mockserver_tests.auto_increment_model import Singer

self.add_insert_result(
"INSERT INTO singers (name) VALUES (@a0) THEN RETURN singers.id"
)
self.add_insert_result("INSERT INTO singers (name) VALUES (@a0) THEN RETURN id")
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
Expand Down
50 changes: 50 additions & 0 deletions 50 test/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,53 @@ class Singer(Base):
singer.name = "New Name"
session.add(singer)
session.commit()

def test_select_table_in_named_schema(self):
class Base(DeclarativeBase):
pass

class Singer(Base):
__tablename__ = "singers"
__table_args__ = {"schema": "my_schema"}
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
name: Mapped[str] = mapped_column(String)

query = (
"SELECT"
" singers_1.id AS my_schema_singers_id,"
" singers_1.name AS my_schema_singers_name\n"
"FROM my_schema.singers AS singers_1\n"
"WHERE singers_1.id = @a0\n"
" LIMIT @a1"
)
add_singer_query_result(query)
engine = create_engine(
"spanner:///projects/p/instances/i/databases/d",
connect_args={"client": self.client, "pool": FixedSizePool(size=10)},
)

insert = "INSERT INTO my_schema.singers (name) VALUES (@a0) THEN RETURN id"
add_single_result(insert, "id", TypeCode.INT64, [("1",)])
with Session(engine) as session:
singer = Singer(name="New Name")
session.add(singer)
session.commit()

update = (
"UPDATE my_schema.singers AS singers_1 "
"SET name=@a0 "
"WHERE singers_1.id = @a1"
)
add_update_count(update, 1)
with Session(engine) as session:
singer = session.query(Singer).filter(Singer.id == 1).first()
singer.name = "New Name"
session.add(singer)
session.commit()

delete = "DELETE FROM my_schema.singers AS singers_1 WHERE singers_1.id = @a0"
add_update_count(delete, 1)
with Session(engine) as session:
singer = session.query(Singer).filter(Singer.id == 1).first()
session.delete(singer)
session.commit()
2 changes: 1 addition & 1 deletion 2 test/mockserver_tests/test_bit_reversed_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_insert_row(self):
add_result(
"INSERT INTO singers (id, name) "
"VALUES ( GET_NEXT_SEQUENCE_VALUE(SEQUENCE singer_id), @a0) "
"THEN RETURN singers.id",
"THEN RETURN id",
result,
)
engine = create_engine(
Expand Down
50 changes: 50 additions & 0 deletions 50 test/system/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
Boolean,
BIGINT,
select,
update,
delete,
)
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import REAL
Expand Down Expand Up @@ -58,6 +60,16 @@ def define_tables(cls, metadata):
Column("name", String(20)),
)

with cls.bind.begin() as conn:
conn.execute(text("CREATE SCHEMA IF NOT EXISTS schema"))
Table(
"users",
metadata,
Column("ID", Integer, primary_key=True),
Column("name", String(20)),
schema="schema",
)

def test_hello_world(self, connection):
greeting = connection.execute(text("select 'Hello World'"))
eq_("Hello World", greeting.fetchone()[0])
Expand Down Expand Up @@ -139,6 +151,12 @@ class User(Base):
ID: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(20))

class SchemaUser(Base):
__tablename__ = "users"
__table_args__ = {"schema": "schema"}
ID: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(20))

engine = connection.engine
with Session(engine) as session:
number = Number(
Expand All @@ -156,3 +174,35 @@ class User(Base):
users = session.scalars(statement).all()
eq_(1, len(users))
is_true(users[0].ID > 0)

with Session(engine) as session:
user = SchemaUser(name="SchemaTest")
session.add(user)
session.commit()

users = session.scalars(
select(SchemaUser).where(SchemaUser.name == "SchemaTest")
).all()
eq_(1, len(users))
is_true(users[0].ID > 0)

session.execute(
update(SchemaUser)
.where(SchemaUser.name == "SchemaTest")
.values(name="NewName")
)
session.commit()

users = session.scalars(
select(SchemaUser).where(SchemaUser.name == "NewName")
).all()
eq_(1, len(users))
is_true(users[0].ID > 0)

session.execute(delete(SchemaUser).where(SchemaUser.name == "NewName"))
session.commit()

users = session.scalars(
select(SchemaUser).where(SchemaUser.name == "NewName")
).all()
eq_(0, len(users))
Morty Proxy This is a proxified and sanitized view of the page, visit original site.