diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 1511024f..faf3bf6b 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -266,6 +266,48 @@ def string_decoder(s): self.autocommit(autocommit) self.messages = [] + @staticmethod + def _parse_endpoint(connection_string): + db_info = connection_string.replace('mysql://', '').strip() + + user, pwd = None, None + if '@' in db_info: + auth_info, db_info = db_info.split('@') + + user = auth_info + if ':' in auth_info: + user, pwd = auth_info.split(':') + + db = None + if '/' in db_info: + db_info, db = db_info.split('/') + + host = db_info + port = None + if ':' in db_info: + host, port = db_info.split(':') + + return host, port, user, pwd, db + + @classmethod + def string_connection(cls, connection_string, **kwargs): + host, port, user, pwd, db = cls._parse_endpoint(connection_string) + + kwargs['host'] = host + if port: + kwargs['port'] = port + + if user: + kwargs['user'] = user + + if pwd: + kwargs['password'] = pwd + + if db: + kwargs['database'] = db + + return cls(**kwargs) + def autocommit(self, on): on = bool(on) if self.get_autocommit() != on: diff --git a/tests/capabilities.py b/tests/capabilities.py index 31aa398e..bbeb41d8 100644 --- a/tests/capabilities.py +++ b/tests/capabilities.py @@ -21,10 +21,13 @@ class DatabaseTest(unittest.TestCase): create_table_extra = '' rows = 10 debug = False + + def get_connection(self): + return connection_factory(**self.connect_kwargs) def setUp(self): import gc - db = connection_factory(**self.connect_kwargs) + db = self.get_connection() self.connection = db self.cursor = db.cursor() self.BLOBUText = u''.join([unichr(i) for i in range(16384)]) diff --git a/tests/test_MySQLdb_capabilities.py b/tests/test_MySQLdb_capabilities.py index f1887575..0076b2be 100644 --- a/tests/test_MySQLdb_capabilities.py +++ b/tests/test_MySQLdb_capabilities.py @@ -5,8 +5,8 @@ from contextlib import closing import unittest import MySQLdb -from MySQLdb.compat import unicode from MySQLdb import cursors +from MySQLdb.compat import unicode from configdb import connection_factory import warnings @@ -199,6 +199,74 @@ def test_binary_prefix(self): c.execute('SELECT CHARSET(%s)', ('str',)) self.assertEqual(c.fetchall()[0][0], 'utf8') + def endpoint_from_params( + self, host, port=None, user=None, pwd=None, database=None, **kwargs + ): + auth = '' + if user: + auth = '{}:{}@'.format(user, pwd) if pwd else '{}@'.format(user) + + db_port = ':{}'.format(port) if port else '' + db_name = '/{}'.format(database) if database else '' + + # mysql://user:pwd@host:port/db_name + return 'mysql://{}{}{}{}'.format(auth, host, db_port, db_name) + + def parse_connection_tests( + self, host='fake-host.com', port=None, user=None, pwd=None, db=None + ): + + endpoint = self.endpoint_from_params(host, port, user, pwd, db) + found = self.connection._parse_endpoint(connection_string=endpoint) + + self.assertEqual(host, found[0]) + self.assertEqual(port, found[1]) + self.assertEqual(user, found[2]) + self.assertEqual(pwd, found[3]) + self.assertEqual(db, found[4]) + + def test_string_connection_parser_full(self): + self.parse_connection_tests( + host='fake-host.com', port='3308', + user='fake', pwd='123456', db='testing' + ) + + def test_string_connection_host(self): + self.parse_connection_tests(host='my-host.gov') + + def test_string_connection_user(self): + self.parse_connection_tests(user='fake-user') + + def test_string_connection_user_pwd(self): + self.parse_connection_tests(user='fake-user', pwd='mypwd123') + + def test_string_connection_port(self): + self.parse_connection_tests(port='3330') + + def test_string_connection_database(self): + self.parse_connection_tests(db='db_fake') + + def test_string_connection_port_database(self): + self.parse_connection_tests(port='1029', db='db_fake') + + +class test_MySQLdb_string_connection(test_MySQLdb): + + def get_connection(self): + try: + from configparser import ConfigParser + except ImportError: + from ConfigParser import ConfigParser + from configdb import conf_path + from MySQLdb.connections import Connection + + config = ConfigParser() + config.readfp(open(conf_path)) + endpoint = self.endpoint_from_params( + **{k: v for k, v in config.items('MySQLdb-tests')} + ) + return Connection.string_connection(endpoint, **self.connect_kwargs) + if __name__ == '__main__': if test_MySQLdb.leak_test: