diff --git a/MySQLdb/__init__.py b/MySQLdb/__init__.py index fc414810..9dc8dabc 100644 --- a/MySQLdb/__init__.py +++ b/MySQLdb/__init__.py @@ -27,6 +27,7 @@ paramstyle = "format" from _mysql import * +from MySQLdb.compat import PY2 from MySQLdb.constants import FIELD_TYPE from MySQLdb.times import Date, Time, Timestamp, \ DateFromTicks, TimeFromTicks, TimestampFromTicks @@ -72,8 +73,12 @@ def test_DBAPISet_set_equality_membership(): def test_DBAPISet_set_inequality_membership(): assert FIELD_TYPE.DATE != STRING -def Binary(x): - return bytes(x) +if PY2: + def Binary(x): + return bytearray(x) +else: + def Binary(x): + return bytes(x) def Connect(*args, **kwargs): """Factory function for connections.Connection.""" diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index d8406db8..8375041c 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -137,6 +137,10 @@ class object, used to create cursors (keyword only) If True, autocommit is enabled. If None, autocommit isn't set and server default is used. + :param bool binary_prefix: + If set, the '_binary' prefix will be used for raw byte query + arguments (e.g. Binary). This is disabled by default. + There are a number of undocumented, non-standard methods. See the documentation for the MySQL C API for some hints on what they do. """ @@ -174,6 +178,7 @@ class object, used to create cursors (keyword only) use_unicode = kwargs2.pop('use_unicode', use_unicode) sql_mode = kwargs2.pop('sql_mode', '') + binary_prefix = kwargs2.pop('binary_prefix', False) client_flag = kwargs.get('client_flag', 0) client_version = tuple([ numeric_part(n) for n in _mysql.get_client_info().split('.')[:2] ]) @@ -197,7 +202,7 @@ class object, used to create cursors (keyword only) db = proxy(self) def _get_string_literal(): - # Note: string_literal() is called for bytes object on Python 3. + # Note: string_literal() is called for bytes object on Python 3 (via bytes_literal) def string_literal(obj, dummy=None): return db.string_literal(obj) return string_literal @@ -206,13 +211,18 @@ def _get_unicode_literal(): if PY2: # unicode_literal is called for only unicode object. def unicode_literal(u, dummy=None): - return db.literal(u.encode(unicode_literal.charset)) + return db.string_literal(u.encode(unicode_literal.charset)) else: # unicode_literal() is called for arbitrary object. def unicode_literal(u, dummy=None): - return db.literal(str(u).encode(unicode_literal.charset)) + return db.string_literal(str(u).encode(unicode_literal.charset)) return unicode_literal + def _get_bytes_literal(): + def bytes_literal(obj, dummy=None): + return b'_binary' + db.string_literal(obj) + return bytes_literal + def _get_string_decoder(): def string_decoder(s): return s.decode(string_decoder.charset) @@ -220,6 +230,7 @@ def string_decoder(s): string_literal = _get_string_literal() self.unicode_literal = unicode_literal = _get_unicode_literal() + bytes_literal = _get_bytes_literal() self.string_decoder = string_decoder = _get_string_decoder() if not charset: charset = self.character_set_name() @@ -234,7 +245,12 @@ def string_decoder(s): self.converter[FIELD_TYPE.VARCHAR].append((None, string_decoder)) self.converter[FIELD_TYPE.BLOB].append((None, string_decoder)) - self.encoders[bytes] = string_literal + if binary_prefix: + self.encoders[bytes] = string_literal if PY2 else bytes_literal + self.encoders[bytearray] = bytes_literal + else: + self.encoders[bytes] = string_literal + self.encoders[unicode] = unicode_literal self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS if self._transactional: diff --git a/tests/test_MySQLdb_capabilities.py b/tests/test_MySQLdb_capabilities.py index 6f1fe27d..f1887575 100644 --- a/tests/test_MySQLdb_capabilities.py +++ b/tests/test_MySQLdb_capabilities.py @@ -2,10 +2,12 @@ # -*- coding: utf-8 -*- import capabilities from datetime import timedelta +from contextlib import closing import unittest import MySQLdb from MySQLdb.compat import unicode from MySQLdb import cursors +from configdb import connection_factory import warnings @@ -180,6 +182,23 @@ def test_warning_propagation(self): finally: c.close() + def test_binary_prefix(self): + # verify prefix behaviour when enabled, disabled and for default (disabled) + for binary_prefix in (True, False, None): + kwargs = self.connect_kwargs.copy() + # needs to be set to can guarantee CHARSET response for normal strings + kwargs['charset'] = 'utf8' + if binary_prefix != None: + kwargs['binary_prefix'] = binary_prefix + + with closing(connection_factory(**kwargs)) as conn: + with closing(conn.cursor()) as c: + c.execute('SELECT CHARSET(%s)', (MySQLdb.Binary(b'raw bytes'),)) + self.assertEqual(c.fetchall()[0][0], 'binary' if binary_prefix else 'utf8') + # normal strings should not get prefix + c.execute('SELECT CHARSET(%s)', ('str',)) + self.assertEqual(c.fetchall()[0][0], 'utf8') + if __name__ == '__main__': if test_MySQLdb.leak_test: