Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ec28c1c

Browse filesBrowse files
committed
feat: adding unit tests for django spanner
1 parent 798e88d commit ec28c1c
Copy full SHA for ec28c1c

15 files changed

+1256
-8
lines changed

‎.gitignore

Copy file name to clipboardExpand all lines: .gitignore
+4Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ django_tests_dir
2424

2525
# Built documentation
2626
docs/_build
27+
28+
# mac hidden files.
29+
.DS_Store
30+

‎django_spanner/lookups.py

Copy file name to clipboardExpand all lines: django_spanner/lookups.py
+4-1Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ def iexact(self, compiler, connection):
101101
# lhs_sql is the expression/column to use as the regular expression.
102102
# Use concat to make the value case-insensitive.
103103
lhs_sql = "CONCAT('^(?i)', " + lhs_sql + ", '$')"
104-
rhs_sql = rhs_sql.replace("%%s", "%s")
104+
if not self.rhs_is_direct_value() and not params:
105+
# If rhs is not a direct value and parameter is not present we want
106+
# to have only 1 formatable argument in rhs_sql else we need 2.
107+
rhs_sql = rhs_sql.replace("%%s", "%s")
105108
# rhs_sql is REGEXP_CONTAINS(%s, %%s), and lhs_sql is the column name.
106109
return rhs_sql % lhs_sql, params
107110

‎noxfile.py

Copy file name to clipboardExpand all lines: noxfile.py
+7-2Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,12 @@ def lint_setup_py(session):
6666
def default(session):
6767
# Install all test dependencies, then install this package in-place.
6868
session.install(
69-
"django~=2.2", "mock", "mock-import", "pytest", "pytest-cov"
69+
"django~=2.2",
70+
"mock",
71+
"mock-import",
72+
"pytest",
73+
"pytest-cov",
74+
"coverage",
7075
)
7176
session.install("-e", ".")
7277

@@ -79,7 +84,7 @@ def default(session):
7984
"--cov-append",
8085
"--cov-config=.coveragerc",
8186
"--cov-report=",
82-
"--cov-fail-under=25",
87+
"--cov-fail-under=80",
8388
os.path.join("tests", "unit"),
8489
*session.posargs
8590
)

‎tests/settings.py

Copy file name to clipboard
+40Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
DEBUG = True
2+
USE_TZ = True
3+
4+
INSTALLED_APPS = [
5+
"django_spanner", # Must be the first entry
6+
"django.contrib.contenttypes",
7+
"django.contrib.auth",
8+
"django.contrib.sites",
9+
"django.contrib.sessions",
10+
"django.contrib.messages",
11+
"django.contrib.staticfiles",
12+
"tests",
13+
]
14+
15+
TIME_ZONE = "UTC"
16+
17+
DATABASES = {
18+
"default": {
19+
"ENGINE": "django_spanner",
20+
"PROJECT": "emulator-local",
21+
"INSTANCE": "django-test-instance",
22+
"NAME": "django-test-db",
23+
}
24+
}
25+
SECRET_KEY = "spanner emulator secret key"
26+
27+
PASSWORD_HASHERS = [
28+
"django.contrib.auth.hashers.MD5PasswordHasher",
29+
]
30+
31+
SITE_ID = 1
32+
33+
CONN_MAX_AGE = 60
34+
35+
ENGINE = "django_spanner"
36+
PROJECT = "emulator-local"
37+
INSTANCE = "django-test-instance"
38+
NAME = "django-test-db"
39+
OPTIONS = {}
40+
AUTOCOMMIT = True

‎tests/unit/django_spanner/__init__.py

Copy file name to clipboardExpand all lines: tests/unit/django_spanner/__init__.py
Whitespace-only changes.

‎tests/unit/django_spanner/models.py

Copy file name to clipboard
+64Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Different models used for testing django-spanner code.
3+
"""
4+
import os
5+
from django.db import models
6+
import django
7+
from django.db.models import Transform
8+
from django.db.models import CharField, TextField
9+
10+
# Load django settings before loading dhango models.
11+
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
12+
django.setup()
13+
14+
15+
# Register transformations for model fields.
16+
class UpperCase(Transform):
17+
lookup_name = "upper"
18+
function = "UPPER"
19+
bilateral = True
20+
21+
22+
CharField.register_lookup(UpperCase)
23+
TextField.register_lookup(UpperCase)
24+
25+
26+
# Models
27+
class ModelDecimalField(models.Model):
28+
field = models.DecimalField()
29+
30+
31+
class ModelCharField(models.Model):
32+
field = models.CharField()
33+
34+
35+
class Item(models.Model):
36+
item_id = models.IntegerField()
37+
name = models.CharField(max_length=10)
38+
created = models.DateTimeField()
39+
modified = models.DateTimeField(blank=True, null=True)
40+
41+
class Meta:
42+
ordering = ["name"]
43+
44+
45+
class Number(models.Model):
46+
num = models.IntegerField()
47+
decimal_num = models.DecimalField(max_digits=5, decimal_places=2)
48+
item = models.ForeignKey(Item, models.CASCADE)
49+
50+
51+
class Author(models.Model):
52+
name = models.CharField(max_length=40)
53+
last_name = models.CharField(max_length=40)
54+
num = models.IntegerField(unique=True)
55+
created = models.DateTimeField()
56+
modified = models.DateTimeField(blank=True, null=True)
57+
58+
59+
class Report(models.Model):
60+
name = models.CharField(max_length=10)
61+
creator = models.ForeignKey(Author, models.CASCADE, null=True)
62+
63+
class Meta:
64+
ordering = ["name"]

‎tests/unit/django_spanner/test_base.py

Copy file name to clipboardExpand all lines: tests/unit/django_spanner/test_base.py
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_property_instance(self):
4949
_ = db_wrapper.instance
5050
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])
5151

52-
def test_property__nodb_connection(self):
52+
def test_property_nodb_connection(self):
5353
db_wrapper = self._make_one(None)
5454
with self.assertRaises(NotImplementedError):
5555
db_wrapper._nodb_connection()
@@ -86,7 +86,7 @@ def test_create_cursor(self):
8686
db_wrapper.create_cursor()
8787
mock_cursor.assert_called_once_with()
8888

89-
def test__set_autocommit(self):
89+
def test_set_autocommit(self):
9090
db_wrapper = self._make_one(self.settings_dict)
9191
db_wrapper.connection = mock_connection = mock.MagicMock()
9292
mock_connection.autocommit = False
@@ -110,7 +110,7 @@ def test_is_usable(self):
110110
mock_connection.cursor = mock.MagicMock(side_effect=Error)
111111
self.assertFalse(db_wrapper.is_usable())
112112

113-
def test__start_transaction_under_autocommit(self):
113+
def test_start_transaction_under_autocommit(self):
114114
db_wrapper = self._make_one(self.settings_dict)
115115
db_wrapper.connection = mock_connection = mock.MagicMock()
116116
mock_connection.cursor = mock_cursor = mock.MagicMock()

‎tests/unit/django_spanner/test_client.py

Copy file name to clipboardExpand all lines: tests/unit/django_spanner/test_client.py
+1-2Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import unittest
99
import os
10+
from google.cloud.spanner_dbapi.exceptions import NotSupportedError
1011

1112

1213
@unittest.skipIf(
@@ -36,8 +37,6 @@ def _make_one(self, *args, **kwargs):
3637
return self._get_target_class()(*args, **kwargs)
3738

3839
def test_runshell(self):
39-
from google.cloud.spanner_dbapi.exceptions import NotSupportedError
40-
4140
db_wrapper = self._make_one(self.settings_dict)
4241

4342
with self.assertRaises(NotSupportedError):
+199Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Use of this source code is governed by a BSD-style
4+
# license that can be found in the LICENSE file or at
5+
# https://developers.google.com/open-source/licenses/bsd
6+
7+
import sys
8+
import unittest
9+
10+
from django.test import SimpleTestCase
11+
from django.core.exceptions import EmptyResultSet
12+
from django.db.utils import DatabaseError
13+
from django_spanner.compiler import SQLCompiler
14+
from django.db.models.query import QuerySet
15+
from .models import Number
16+
17+
18+
@unittest.skipIf(
19+
sys.version_info < (3, 6), reason="Skipping Python versions <= 3.5"
20+
)
21+
class TestUtils(SimpleTestCase):
22+
settings_dict = {"dummy_param": "dummy"}
23+
24+
def _get_target_class(self):
25+
from django_spanner.base import DatabaseWrapper
26+
27+
return DatabaseWrapper
28+
29+
def _make_one(self, *args, **kwargs):
30+
return self._get_target_class()(*args, **kwargs)
31+
32+
def test_unsupported_ordering_slicing_raises_db_error(self):
33+
"""
34+
Tries limit/offset and order by in subqueries which are not supported
35+
by spanner.
36+
"""
37+
qs1 = Number.objects.all()
38+
qs2 = Number.objects.all()
39+
msg = "LIMIT/OFFSET not allowed in subqueries of compound statements"
40+
with self.assertRaisesMessage(DatabaseError, msg):
41+
list(qs1.union(qs2[:10]))
42+
msg = "ORDER BY not allowed in subqueries of compound statements"
43+
with self.assertRaisesMessage(DatabaseError, msg):
44+
list(qs1.order_by("id").union(qs2))
45+
46+
def test_get_combinator_sql_all_union_sql_generated(self):
47+
"""
48+
Tries union sql generator.
49+
"""
50+
connection = self._make_one(self.settings_dict)
51+
52+
qs1 = Number.objects.filter(num__lte=1).values("num")
53+
qs2 = Number.objects.filter(num__gte=8).values("num")
54+
qs4 = qs1.union(qs2)
55+
56+
compiler = SQLCompiler(qs4.query, connection, "default")
57+
sql_compiled, params = compiler.get_combinator_sql("union", True)
58+
self.assertEqual(
59+
sql_compiled,
60+
[
61+
"SELECT tests_number.num FROM tests_number WHERE "
62+
+ "tests_number.num <= %s UNION ALL SELECT tests_number.num "
63+
+ "FROM tests_number WHERE tests_number.num >= %s"
64+
],
65+
)
66+
self.assertEqual(params, [1, 8])
67+
68+
def test_get_combinator_sql_distinct_union_sql_generated(self):
69+
"""
70+
Tries union sql generator with distinct.
71+
"""
72+
connection = self._make_one(self.settings_dict)
73+
74+
qs1 = Number.objects.filter(num__lte=1).values("num")
75+
qs2 = Number.objects.filter(num__gte=8).values("num")
76+
qs4 = qs1.union(qs2)
77+
78+
compiler = SQLCompiler(qs4.query, connection, "default")
79+
sql_compiled, params = compiler.get_combinator_sql("union", False)
80+
self.assertEqual(
81+
sql_compiled,
82+
[
83+
"SELECT tests_number.num FROM tests_number WHERE "
84+
+ "tests_number.num <= %s UNION DISTINCT SELECT "
85+
+ "tests_number.num FROM tests_number WHERE "
86+
+ "tests_number.num >= %s"
87+
],
88+
)
89+
self.assertEqual(params, [1, 8])
90+
91+
def test_get_combinator_sql_difference_all_sql_generated(self):
92+
"""
93+
Tries difference sql generator.
94+
"""
95+
connection = self._make_one(self.settings_dict)
96+
97+
qs1 = Number.objects.filter(num__lte=1).values("num")
98+
qs2 = Number.objects.filter(num__gte=8).values("num")
99+
qs4 = qs1.difference(qs2)
100+
101+
compiler = SQLCompiler(qs4.query, connection, "default")
102+
sql_compiled, params = compiler.get_combinator_sql("difference", True)
103+
104+
self.assertEqual(
105+
sql_compiled,
106+
[
107+
"SELECT tests_number.num FROM tests_number WHERE "
108+
+ "tests_number.num <= %s EXCEPT ALL SELECT tests_number.num "
109+
+ "FROM tests_number WHERE tests_number.num >= %s"
110+
],
111+
)
112+
self.assertEqual(params, [1, 8])
113+
114+
def test_get_combinator_sql_difference_distinct_sql_generated(self):
115+
"""
116+
Tries difference sql generator with distinct.
117+
"""
118+
connection = self._make_one(self.settings_dict)
119+
120+
qs1 = Number.objects.filter(num__lte=1).values("num")
121+
qs2 = Number.objects.filter(num__gte=8).values("num")
122+
qs4 = qs1.difference(qs2)
123+
124+
compiler = SQLCompiler(qs4.query, connection, "default")
125+
sql_compiled, params = compiler.get_combinator_sql("difference", False)
126+
127+
self.assertEqual(
128+
sql_compiled,
129+
[
130+
"SELECT tests_number.num FROM tests_number WHERE "
131+
+ "tests_number.num <= %s EXCEPT DISTINCT SELECT "
132+
+ "tests_number.num FROM tests_number WHERE "
133+
+ "tests_number.num >= %s"
134+
],
135+
)
136+
self.assertEqual(params, [1, 8])
137+
138+
def test_get_combinator_sql_union_and_difference_query_together(self):
139+
"""
140+
Tries sql generator with union of queryset with queryset of difference.
141+
"""
142+
connection = self._make_one(self.settings_dict)
143+
144+
qs1 = Number.objects.filter(num__lte=1).values("num")
145+
qs2 = Number.objects.filter(num__gte=8).values("num")
146+
qs3 = Number.objects.filter(num__exact=10).values("num")
147+
qs4 = qs1.union(qs2.difference(qs3))
148+
149+
compiler = SQLCompiler(qs4.query, connection, "default")
150+
sql_compiled, params = compiler.get_combinator_sql("union", False)
151+
self.assertEqual(
152+
sql_compiled,
153+
[
154+
"SELECT tests_number.num FROM tests_number WHERE "
155+
+ "tests_number.num <= %s UNION DISTINCT ("
156+
+ "SELECT tests_number.num FROM tests_number WHERE "
157+
+ "tests_number.num >= %s EXCEPT DISTINCT "
158+
+ "SELECT tests_number.num FROM tests_number "
159+
+ "WHERE tests_number.num = %s)"
160+
],
161+
)
162+
self.assertEqual(params, [1, 8, 10])
163+
164+
def test_get_combinator_sql_parentheses_in_compound_not_supported(self):
165+
"""
166+
Tries sql generator with union of queryset with queryset of difference,
167+
adding support for parentheses in compound sql statement.
168+
"""
169+
connection = self._make_one(self.settings_dict)
170+
171+
qs1 = Number.objects.filter(num__lte=1).values("num")
172+
qs2 = Number.objects.filter(num__gte=8).values("num")
173+
qs3 = Number.objects.filter(num__exact=10).values("num")
174+
qs4 = qs1.union(qs2.difference(qs3))
175+
176+
compiler = SQLCompiler(qs4.query, connection, "default")
177+
compiler.connection.features.supports_parentheses_in_compound = False
178+
sql_compiled, params = compiler.get_combinator_sql("union", False)
179+
self.assertEqual(
180+
sql_compiled,
181+
[
182+
"SELECT tests_number.num FROM tests_number WHERE "
183+
+ "tests_number.num <= %s UNION DISTINCT SELECT * FROM ("
184+
+ "SELECT tests_number.num FROM tests_number WHERE "
185+
+ "tests_number.num >= %s EXCEPT DISTINCT "
186+
+ "SELECT tests_number.num FROM tests_number "
187+
+ "WHERE tests_number.num = %s)"
188+
],
189+
)
190+
self.assertEqual(params, [1, 8, 10])
191+
192+
def test_get_combinator_sql_empty_queryset_raises_exception(self):
193+
"""
194+
Tries sql generator with empty queryset.
195+
"""
196+
connection = self._make_one(self.settings_dict)
197+
compiler = SQLCompiler(QuerySet().query, connection, "default")
198+
with self.assertRaises(EmptyResultSet):
199+
compiler.get_combinator_sql("union", False)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.