Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit c3318e8

Browse filesBrowse files
committed
Merge pull request PyMySQL#83 from methane/fix/executemany-double-percent
Port executemany() implementation from PyMySQL
2 parents e76b691 + 57dd34d commit c3318e8
Copy full SHA for c3318e8

File tree

Expand file treeCollapse file tree

3 files changed

+184
-78
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+184
-78
lines changed

‎MySQLdb/connections.py

Copy file name to clipboardExpand all lines: MySQLdb/connections.py
+4-1Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def numeric_part(s):
6565

6666

6767
class Connection(_mysql.connection):
68-
6968
"""MySQL Database Connection Object"""
7069

7170
default_cursor = cursors.Cursor
@@ -278,6 +277,9 @@ def cursor(self, cursorclass=None):
278277
return (cursorclass or self.cursorclass)(self)
279278

280279
def query(self, query):
280+
# Since _mysql releases GIL while querying, we need immutable buffer.
281+
if isinstance(query, bytearray):
282+
query = bytes(query)
281283
if self.waiter is not None:
282284
self.send_query(query)
283285
self.waiter(self.fileno())
@@ -353,6 +355,7 @@ def set_character_set(self, charset):
353355
self.store_result()
354356
self.string_decoder.charset = py_charset
355357
self.unicode_literal.charset = py_charset
358+
self.encoding = py_charset
356359

357360
def set_sql_mode(self, sql_mode):
358361
"""Set the connection sql_mode. See MySQL documentation for

‎MySQLdb/cursors.py

Copy file name to clipboardExpand all lines: MySQLdb/cursors.py
+106-77Lines changed: 106 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,34 @@
22
33
This module implements Cursors of various types for MySQLdb. By
44
default, MySQLdb uses the Cursor class.
5-
65
"""
7-
6+
from __future__ import print_function, absolute_import
7+
from functools import partial
88
import re
99
import sys
10-
PY2 = sys.version_info[0] == 2
1110

1211
from MySQLdb.compat import unicode
12+
from _mysql_exceptions import (
13+
Warning, Error, InterfaceError, DataError,
14+
DatabaseError, OperationalError, IntegrityError, InternalError,
15+
NotSupportedError, ProgrammingError)
1316

14-
restr = r"""
15-
\s
16-
values
17-
\s*
18-
(
19-
\(
20-
[^()']*
21-
(?:
22-
(?:
23-
(?:\(
24-
# ( - editor highlighting helper
25-
.*
26-
\))
27-
|
28-
'
29-
[^\\']*
30-
(?:\\.[^\\']*)*
31-
'
32-
)
33-
[^()']*
34-
)*
35-
\)
36-
)
37-
"""
3817

39-
insert_values = re.compile(restr, re.S | re.I | re.X)
18+
PY2 = sys.version_info[0] == 2
19+
if PY2:
20+
text_type = unicode
21+
else:
22+
text_type = str
23+
4024

41-
from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \
42-
DatabaseError, OperationalError, IntegrityError, InternalError, \
43-
NotSupportedError, ProgrammingError
25+
#: Regular expression for :meth:`Cursor.executemany`.
26+
#: executemany only suports simple bulk insert.
27+
#: You can use it to load large dataset.
28+
RE_INSERT_VALUES = re.compile(
29+
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
30+
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
31+
r"(\s*(?:ON DUPLICATE.*)?)\Z",
32+
re.IGNORECASE | re.DOTALL)
4433

4534

4635
class BaseCursor(object):
@@ -60,6 +49,12 @@ class BaseCursor(object):
6049
default number of rows fetchmany() will fetch
6150
"""
6251

52+
#: Max stetement size which :meth:`executemany` generates.
53+
#:
54+
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
55+
#: Default value of max_allowed_packet is 1048576.
56+
max_stmt_length = 64*1024
57+
6358
from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
6459
DatabaseError, DataError, OperationalError, IntegrityError, \
6560
InternalError, ProgrammingError, NotSupportedError
@@ -102,6 +97,32 @@ def __exit__(self, *exc_info):
10297
del exc_info
10398
self.close()
10499

100+
def _ensure_bytes(self, x, encoding=None):
101+
if isinstance(x, text_type):
102+
x = x.encode(encoding)
103+
elif isinstance(x, (tuple, list)):
104+
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
105+
return x
106+
107+
def _escape_args(self, args, conn):
108+
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
109+
110+
if isinstance(args, (tuple, list)):
111+
if PY2:
112+
args = tuple(map(ensure_bytes, args))
113+
return tuple(conn.literal(arg) for arg in args)
114+
elif isinstance(args, dict):
115+
if PY2:
116+
args = dict((ensure_bytes(key), ensure_bytes(val)) for
117+
(key, val) in args.items())
118+
return dict((key, conn.literal(val)) for (key, val) in args.items())
119+
else:
120+
# If it's not a dictionary let's try escaping it anyways.
121+
# Worst case it will throw a Value error
122+
if PY2:
123+
args = ensure_bytes(args)
124+
return conn.literal(args)
125+
105126
def _check_executed(self):
106127
if not self._executed:
107128
self.errorhandler(self, ProgrammingError, "execute() first")
@@ -230,62 +251,70 @@ def execute(self, query, args=None):
230251
return res
231252

232253
def executemany(self, query, args):
254+
# type: (str, list) -> int
233255
"""Execute a multi-row query.
234256
235-
query -- string, query to execute on server
236-
237-
args
238-
239-
Sequence of sequences or mappings, parameters to use with
240-
query.
241-
242-
Returns long integer rows affected, if any.
257+
:param query: query to execute on server
258+
:param args: Sequence of sequences or mappings. It is used as parameter.
259+
:return: Number of rows affected, if any.
243260
244261
This method improves performance on multiple-row INSERT and
245262
REPLACE. Otherwise it is equivalent to looping over args with
246263
execute().
247264
"""
248265
del self.messages[:]
249-
db = self._get_db()
250-
if not args: return
251-
if PY2 and isinstance(query, unicode):
252-
query = query.encode(db.unicode_literal.charset)
253-
elif not PY2 and isinstance(query, bytes):
254-
query = query.decode(db.unicode_literal.charset)
255-
m = insert_values.search(query)
256-
if not m:
257-
r = 0
258-
for a in args:
259-
r = r + self.execute(query, a)
260-
return r
261-
p = m.start(1)
262-
e = m.end(1)
263-
qv = m.group(1)
264-
try:
265-
q = []
266-
for a in args:
267-
if isinstance(a, dict):
268-
q.append(qv % dict((key, db.literal(item))
269-
for key, item in a.items()))
266+
267+
if not args:
268+
return
269+
270+
m = RE_INSERT_VALUES.match(query)
271+
if m:
272+
q_prefix = m.group(1) % ()
273+
q_values = m.group(2).rstrip()
274+
q_postfix = m.group(3) or ''
275+
assert q_values[0] == '(' and q_values[-1] == ')'
276+
return self._do_execute_many(q_prefix, q_values, q_postfix, args,
277+
self.max_stmt_length,
278+
self._get_db().encoding)
279+
280+
self.rowcount = sum(self.execute(query, arg) for arg in args)
281+
return self.rowcount
282+
283+
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
284+
conn = self._get_db()
285+
escape = self._escape_args
286+
if isinstance(prefix, text_type):
287+
prefix = prefix.encode(encoding)
288+
if PY2 and isinstance(values, text_type):
289+
values = values.encode(encoding)
290+
if isinstance(postfix, text_type):
291+
postfix = postfix.encode(encoding)
292+
sql = bytearray(prefix)
293+
args = iter(args)
294+
v = values % escape(next(args), conn)
295+
if isinstance(v, text_type):
296+
if PY2:
297+
v = v.encode(encoding)
298+
else:
299+
v = v.encode(encoding, 'surrogateescape')
300+
sql += v
301+
rows = 0
302+
for arg in args:
303+
v = values % escape(arg, conn)
304+
if isinstance(v, text_type):
305+
if PY2:
306+
v = v.encode(encoding)
270307
else:
271-
q.append(qv % tuple([db.literal(item) for item in a]))
272-
except TypeError as msg:
273-
if msg.args[0] in ("not enough arguments for format string",
274-
"not all arguments converted"):
275-
self.errorhandler(self, ProgrammingError, msg.args[0])
308+
v = v.encode(encoding, 'surrogateescape')
309+
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
310+
rows += self.execute(sql + postfix)
311+
sql = bytearray(prefix)
276312
else:
277-
self.errorhandler(self, TypeError, msg)
278-
except (SystemExit, KeyboardInterrupt):
279-
raise
280-
except:
281-
exc, value = sys.exc_info()[:2]
282-
self.errorhandler(self, exc, value)
283-
qs = '\n'.join([query[:p], ',\n'.join(q), query[e:]])
284-
if not PY2:
285-
qs = qs.encode(db.unicode_literal.charset, 'surrogateescape')
286-
r = self._query(qs)
287-
if not self._defer_warnings: self._warning_check()
288-
return r
313+
sql += b','
314+
sql += v
315+
rows += self.execute(sql + postfix)
316+
self.rowcount = rows
317+
return rows
289318

290319
def callproc(self, procname, args=()):
291320
"""Execute stored procedure procname with args

‎tests/test_cursor.py

Copy file name to clipboard
+74Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import py.test
2+
import MySQLdb.cursors
3+
from configdb import connection_factory
4+
5+
6+
_conns = []
7+
_tables = []
8+
9+
def connect(**kwargs):
10+
conn = connection_factory(**kwargs)
11+
_conns.append(conn)
12+
return conn
13+
14+
15+
def teardown_function(function):
16+
if _tables:
17+
c = _conns[0]
18+
cur = c.cursor()
19+
for t in _tables:
20+
cur.execute("DROP TABLE %s" % (t,))
21+
cur.close()
22+
del _tables[:]
23+
24+
for c in _conns:
25+
c.close()
26+
del _conns[:]
27+
28+
29+
def test_executemany():
30+
conn = connect()
31+
cursor = conn.cursor()
32+
33+
cursor.execute("create table test (data varchar(10))")
34+
_tables.append("test")
35+
36+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%s, %s)")
37+
assert m is not None, 'error parse %s'
38+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
39+
40+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id)s, %(name)s)")
41+
assert m is not None, 'error parse %(name)s'
42+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
43+
44+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s)")
45+
assert m is not None, 'error parse %(id_name)s'
46+
assert m.group(3) == '', 'group 3 not blank, bug in RE_INSERT_VALUES?'
47+
48+
m = MySQLdb.cursors.RE_INSERT_VALUES.match("INSERT INTO TEST (ID, NAME) VALUES (%(id_name)s, %(name)s) ON duplicate update")
49+
assert m is not None, 'error parse %(id_name)s'
50+
assert m.group(3) == ' ON duplicate update', 'group 3 not ON duplicate update, bug in RE_INSERT_VALUES?'
51+
52+
# cursor._executed myst bee "insert into test (data) values (0),(1),(2),(3),(4),(5),(6),(7),(8),(9)"
53+
# list args
54+
data = range(10)
55+
cursor.executemany("insert into test (data) values (%s)", data)
56+
assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %s not in one query'
57+
58+
# dict args
59+
data_dict = [{'data': i} for i in range(10)]
60+
cursor.executemany("insert into test (data) values (%(data)s)", data_dict)
61+
assert cursor._executed.endswith(b",(7),(8),(9)"), 'execute many with %(data)s not in one query'
62+
63+
# %% in column set
64+
cursor.execute("""\
65+
CREATE TABLE percent_test (
66+
`A%` INTEGER,
67+
`B%` INTEGER)""")
68+
try:
69+
q = "INSERT INTO percent_test (`A%%`, `B%%`) VALUES (%s, %s)"
70+
assert MySQLdb.cursors.RE_INSERT_VALUES.match(q) is not None
71+
cursor.executemany(q, [(3, 4), (5, 6)])
72+
assert cursor._executed.endswith(b"(3, 4),(5, 6)"), "executemany with %% not in one query"
73+
finally:
74+
cursor.execute("DROP TABLE IF EXISTS percent_test")

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.