diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index 3e4a0e7b..94828862 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -73,8 +73,7 @@ def __init__(self, connection): self.messages = [] self.errorhandler = connection.errorhandler self._result = None - self._warnings = 0 - self._info = None + self._warnings = None self.rownumber = None def close(self): @@ -128,29 +127,37 @@ def _check_executed(self): def _warning_check(self): from warnings import warn + db = self._get_db() + + # None => warnings not interrogated for current query yet + # 0 => no warnings exists or have been handled already for this query + if self._warnings is None: + self._warnings = db.warning_count() if self._warnings: + # Only propagate warnings for current query once + warning_count = self._warnings + self._warnings = 0 # When there is next result, fetching warnings cause "command # out of sync" error. if self._result and self._result.has_next: - msg = "There are %d MySQL warnings." % (self._warnings,) + msg = "There are %d MySQL warnings." % (warning_count,) self.messages.append(msg) - warn(msg, self.Warning, 3) + warn(self.Warning(0, msg), stacklevel=3) return - warnings = self._get_db().show_warnings() + warnings = db.show_warnings() if warnings: # This is done in two loops in case # Warnings are set to raise exceptions. for w in warnings: self.messages.append((self.Warning, w)) for w in warnings: - msg = w[-1] - if not PY2 and isinstance(msg, bytes): - msg = msg.decode() - warn(msg, self.Warning, 3) - elif self._info: - self.messages.append((self.Warning, self._info)) - warn(self._info, self.Warning, 3) + warn(self.Warning(*w[1:3]), stacklevel=3) + else: + info = db.info() + if info: + self.messages.append((self.Warning, info)) + warn(self.Warning(0, info), stacklevel=3) def nextset(self): """Advance to the next result set. @@ -180,8 +187,7 @@ def _do_get_result(self): self.description = self._result and self._result.describe() or None self.description_flags = self._result and self._result.field_flags() or None self.lastrowid = db.insert_id() - self._warnings = db.warning_count() - self._info = db.info() + self._warnings = None def setinputsizes(self, *args): """Does nothing, required by DB API.""" diff --git a/tests/test_MySQLdb_capabilities.py b/tests/test_MySQLdb_capabilities.py index 1ec32f71..0adc4095 100644 --- a/tests/test_MySQLdb_capabilities.py +++ b/tests/test_MySQLdb_capabilities.py @@ -3,6 +3,8 @@ from datetime import timedelta import unittest import MySQLdb +from MySQLdb.compat import unicode +from MySQLdb import cursors import warnings @@ -155,6 +157,28 @@ def test_reraise_exception(self): return self.fail("Should raise ProgrammingError") + def test_warning_propagation(self): + with warnings.catch_warnings(): + # Ignore all warnings other than MySQLdb generated ones + warnings.simplefilter("ignore") + warnings.simplefilter("error", category=MySQLdb.Warning) + + # verify for both buffered and unbuffered cursor types + for cursor_class in (cursors.Cursor, cursors.SSCursor): + c = self.connection.cursor(cursor_class) + try: + c.execute("SELECT CAST('124b' AS SIGNED)") + c.fetchall() + except MySQLdb.Warning as e: + # Warnings should have errorcode and string message, just like exceptions + self.assertEqual(len(e.args), 2) + self.assertEqual(e.args[0], 1292) + self.assertTrue(isinstance(e.args[1], unicode)) + else: + self.fail("Should raise Warning") + finally: + c.close() + if __name__ == '__main__': if test_MySQLdb.leak_test: