From f433a80cb0d8d691b25a4d826866fda7916cd2f1 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Wed, 3 Aug 2016 07:07:55 -0500 Subject: [PATCH 01/10] Revert "Merge tag '3.6.0' into cassandra-test" This reverts commit ad76aa321f126cf1e64121219805c8819b29a0da, reversing changes made to 0f81911ce83cfc6e69b593692d63cb3e7993fadb. --- CHANGELOG.rst | 31 --- cassandra/__init__.py | 2 +- cassandra/cluster.py | 92 +++----- cassandra/concurrent.py | 14 +- cassandra/connection.py | 18 +- cassandra/cqlengine/columns.py | 92 +------- cassandra/cqlengine/management.py | 36 +-- cassandra/cqlengine/models.py | 33 +-- cassandra/cqlengine/query.py | 74 +----- cassandra/cqlengine/statements.py | 4 +- cassandra/cqltypes.py | 24 +- cassandra/encoder.py | 3 +- cassandra/io/eventletreactor.py | 14 ++ cassandra/io/geventreactor.py | 18 +- cassandra/io/libevreactor.py | 10 +- cassandra/metadata.py | 27 ++- cassandra/metrics.py | 35 +-- cassandra/numpy_parser.pyx | 30 +-- cassandra/protocol.py | 59 ++--- cassandra/query.py | 31 +-- cassandra/row_parser.pyx | 4 +- cassandra/type_codes.py | 1 + cassandra/util.py | 4 - docs.yaml | 5 - docs/api/cassandra/cqlengine/models.rst | 6 +- docs/api/cassandra/cqlengine/query.rst | 8 - docs/cqlengine/queryset.rst | 36 --- docs/getting_started.rst | 2 +- test-requirements.txt | 2 +- tests/integration/__init__.py | 84 +------ tests/integration/cqlengine/__init__.py | 2 +- .../cqlengine/columns/test_validation.py | 223 +----------------- .../model/test_class_construction.py | 23 +- .../integration/cqlengine/model/test_model.py | 36 +-- .../cqlengine/model/test_model_io.py | 63 +---- .../cqlengine/model/test_updates.py | 36 +-- .../integration/cqlengine/query/test_named.py | 4 +- .../cqlengine/query/test_queryoperators.py | 6 +- .../cqlengine/query/test_queryset.py | 69 +----- .../cqlengine/test_context_query.py | 127 ---------- .../cqlengine/test_lwt_conditional.py | 15 -- tests/integration/cqlengine/test_ttl.py | 43 +--- tests/integration/long/test_ssl.py | 42 +--- tests/integration/standard/test_cluster.py | 89 +------ tests/integration/standard/test_connection.py | 2 +- .../standard/test_custom_protocol_handler.py | 4 +- .../standard/test_cython_protocol_handlers.py | 91 ++++--- tests/integration/standard/test_metadata.py | 134 ++++------- tests/integration/standard/test_metrics.py | 117 +-------- tests/integration/standard/test_query.py | 143 +++++------ tests/integration/standard/test_types.py | 4 +- tests/integration/standard/utils.py | 1 - tests/unit/cqlengine/__init__.py | 14 -- tests/unit/cqlengine/test_columns.py | 68 ------ tests/unit/test_concurrent.py | 19 -- tests/unit/test_connection.py | 8 +- tests/unit/test_host_connection_pool.py | 17 +- tests/unit/test_parameter_binding.py | 8 +- tests/unit/test_response_future.py | 12 +- 59 files changed, 461 insertions(+), 1758 deletions(-) delete mode 100644 tests/integration/cqlengine/test_context_query.py delete mode 100644 tests/unit/cqlengine/__init__.py delete mode 100644 tests/unit/cqlengine/test_columns.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3db920828b..273657131a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,34 +1,3 @@ -3.6.0 -===== -August 1, 2016 - -Features --------- -* Handle null values in NumpyProtocolHandler (PYTHON-553) -* Collect greplin scales stats per cluster (PYTHON-561) -* Update mock unit test dependency requirement (PYTHON-591) -* Handle Missing CompositeType metadata following C* upgrade (PYTHON-562) -* Improve Host.is_up state for HostDistance.IGNORED hosts (PYTHON-551) -* Utilize 
v2 protocol's ability to skip result set metadata for prepared statement execution (PYTHON-71) -* Return from Cluster.connect() when first contact point connection(pool) is opened (PYTHON-105) -* cqlengine: Add ContextQuery to allow cqlengine models to switch the keyspace context easily (PYTHON-598) - -Bug Fixes ---------- -* Fix geventreactor with SSL support (PYTHON-600) -* Don't downgrade protocol version if explicitly set (PYTHON-537) -* Nonexistent contact point tries to connect indefinitely (PYTHON-549) -* Execute_concurrent can exceed max recursion depth in failure mode (PYTHON-585) -* Libev loop shutdown race (PYTHON-578) -* Include aliases in DCT type string (PYTHON-579) -* cqlengine: Comparison operators for Columns (PYTHON-595) -* cqlengine: disentangle default_time_to_live table option from model query default TTL (PYTHON-538) -* cqlengine: pk__token column name issue with the equality operator (PYTHON-584) -* cqlengine: Fix "__in" filtering operator converts True to string "True" automatically (PYTHON-596) -* cqlengine: Avoid LWTExceptions when updating columns that are part of the condition (PYTHON-580) -* cqlengine: Cannot execute a query when the filter contains all columns (PYTHON-599) -* cqlengine: routing key computation issue when a primary key column is overriden by model inheritance (PYTHON-576) - 3.5.0 ===== June 27, 2016 diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 1a02d8a892..c8212c70e3 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -22,7 +22,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 6, 0) +__version_info__ = (3, 5, 0) __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 536ae71c14..99509d2233 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -20,7 +20,7 @@ import atexit from collections import defaultdict, Mapping -from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures +from concurrent.futures import ThreadPoolExecutor, wait as wait_futures from copy import copy from functools import partial, wraps from itertools import groupby, count @@ -356,11 +356,10 @@ class Cluster(object): """ The maximum version of the native protocol to use. - If not set in the constructor, the driver will automatically downgrade - version based on a negotiation with the server, but it is most efficient - to set this to the maximum supported by your version of Cassandra. - Setting this will also prevent conflicting versions negotiated if your - cluster is upgraded. + The driver will automatically downgrade version based on a negotiation with + the server, but it is most efficient to set this to the maximum supported + by your version of Cassandra. Setting this will also prevent conflicting + versions negotiated if your cluster is upgraded. Version 2 of the native protocol adds support for lightweight transactions, batch operations, and automatic query paging. 
The v2 protocol is @@ -389,8 +388,6 @@ class Cluster(object): +-------------------+-------------------+ | 2.2 | 1, 2, 3, 4 | +-------------------+-------------------+ - | 3.x | 3, 4 | - +-------------------+-------------------+ """ compression = True @@ -722,7 +719,6 @@ def token_metadata_enabled(self, enabled): _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None - _protocol_version_explicit = False _user_types = None """ @@ -746,7 +742,7 @@ def __init__(self, ssl_options=None, sockopts=None, cql_version=None, - protocol_version=_NOT_SET, + protocol_version=4, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, @@ -781,11 +777,7 @@ def __init__(self, for endpoint in socket.getaddrinfo(a, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)] self.compression = compression - - if protocol_version is not _NOT_SET: - self.protocol_version = protocol_version - self._protocol_version_explicit = True - + self.protocol_version = protocol_version self.auth_provider = auth_provider if load_balancing_policy is not None: @@ -1125,9 +1117,6 @@ def _make_connection_kwargs(self, address, kwargs_dict): return kwargs_dict def protocol_downgrade(self, host_addr, previous_version): - if self._protocol_version_explicit: - raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) - new_version = previous_version - 1 if new_version < self.protocol_version: if new_version >= MIN_SUPPORTED_VERSION: @@ -1138,7 +1127,7 @@ def protocol_downgrade(self, host_addr, previous_version): else: raise DriverException("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) - def connect(self, keyspace=None, wait_for_all_pools=False): + def connect(self, keyspace=None): """ Creates and returns a new :class:`~.Session` object. 
If `keyspace` is specified, that keyspace will be the default keyspace for @@ -1165,13 +1154,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False): try: self.control_connection.connect() - - # we set all contact points up for connecting, but we won't infer state after this - for address in self.contact_points_resolved: - h = self.metadata.get_host(address) - if h and self.profile_manager.distance(h) == HostDistance.IGNORED: - h.is_up = None - log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " @@ -1185,9 +1167,9 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) self._is_setup = True - session = self._new_session(keyspace) - if wait_for_all_pools: - wait_futures(session._initial_connect_futures) + session = self._new_session() + if keyspace: + session.set_keyspace(keyspace) return session def get_connection_holders(self): @@ -1231,8 +1213,8 @@ def __enter__(self): def __exit__(self, *args): self.shutdown() - def _new_session(self, keyspace): - session = Session(self, self.metadata.all_hosts(), keyspace) + def _new_session(self): + session = Session(self, self.metadata.all_hosts()) self._session_register_user_types(session) self.sessions.add(session) return session @@ -1352,7 +1334,6 @@ def on_up(self, host): else: if not have_future: with host.lock: - host.set_up() host._currently_handling_node_up = False # for testing purposes @@ -1391,11 +1372,10 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): return with host.lock: - was_up = host.is_up - host.set_down() - if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + if (not host.is_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): return + host.set_down() log.warning("Host %s has been marked down", host) @@ -1908,10 +1888,9 @@ def default_serial_consistency_level(self, cl): _profile_manager = None _metrics = None - def __init__(self, cluster, hosts, keyspace=None): + def __init__(self, cluster, hosts): self.cluster = cluster self.hosts = hosts - self.keyspace = keyspace self._lock = RLock() self._pools = {} @@ -1922,13 +1901,14 @@ def __init__(self, cluster, hosts, keyspace=None): self.encoder = Encoder() # create connection pools in parallel - self._initial_connect_futures = set() + futures = [] for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) - if future: - self._initial_connect_futures.add(future) - wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) + if future is not None: + futures.append(future) + for future in futures: + future.result() def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT): """ @@ -2065,11 +2045,11 @@ def _create_response_future(self, query, parameters, trace, custom_payload, time query_string, cl, serial_cl, fetch_size, timestamp=timestamp) elif isinstance(query, BoundStatement): - prepared_statement = query.prepared_statement message = ExecuteMessage( - prepared_statement.query_id, query.values, cl, + query.prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, - timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata)) + timestamp=timestamp) + prepared_statement = query.prepared_statement elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( @@ 
-2144,14 +2124,14 @@ def prepare(self, query, custom_payload=None): future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() - query_id, bind_metadata, pk_indexes, result_metadata = future.result() + query_id, column_metadata, pk_indexes = future.result() except Exception: log.exception("Error preparing query:") raise prepared_statement = PreparedStatement.from_message( - query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, - self._protocol_version, result_metadata) + query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, + self._protocol_version) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(query_id, prepared_statement) @@ -2209,7 +2189,7 @@ def shutdown(self): else: self.is_shutdown = True - for pool in list(self._pools.values()): + for pool in self._pools.values(): pool.shutdown() def __enter__(self): @@ -2794,8 +2774,9 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, for old_host in self._cluster.metadata.all_hosts(): if old_host.address != connection.host and old_host.address not in found_hosts: should_rebuild_token_map = True - log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) - self._cluster.remove_host(old_host) + if old_host.address not in self._cluster.contact_points: + log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) + self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: @@ -2948,13 +2929,14 @@ def _get_schema_mismatches(self, peers_result, local_result, local_address): if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) + pm = self._cluster.profile_manager for row in peers_result: schema_ver = row.get('schema_version') if not schema_ver: continue addr = self._rpc_from_peer_row(row) peer = self._cluster.metadata.get_host(addr) - if peer and peer.is_up is not False: + if peer and peer.is_up and pm.distance(peer) != HostDistance.IGNORED: versions[schema_ver].add(addr) if len(versions) == 1: @@ -3272,9 +3254,7 @@ def _query(self, host, message=None, cb=None): # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection - result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] - connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message, - result_metadata=result_meta) + connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) @@ -3777,8 +3757,8 @@ def add_callbacks(self, callback, errback, def clear_callbacks(self): with self._callback_lock: - self._callbacks = [] - self._errbacks = [] + self._callback = [] + self._errback = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index a08c0292e3..48cbab3e24 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -94,8 +94,6 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, 
rais class _ConcurrentExecutor(object): - max_error_recursion = 100 - def __init__(self, session, statements_and_params): self.session = session self._enum_statements = enumerate(iter(statements_and_params)) @@ -104,7 +102,6 @@ def __init__(self, session, statements_and_params): self._results_queue = [] self._current = 0 self._exec_count = 0 - self._exec_depth = 0 def execute(self, concurrency, fail_fast): self._fail_fast = fail_fast @@ -128,7 +125,6 @@ def _execute_next(self): pass def _execute(self, idx, statement, params): - self._exec_depth += 1 try: future = self.session.execute_async(statement, params, timeout=None) args = (future, idx) @@ -139,15 +135,7 @@ def _execute(self, idx, statement, params): # exc_info with fail_fast to preserve stack trace info when raising on the client thread # (matches previous behavior -- not sure why we wouldn't want stack trace in the other case) e = sys.exc_info() if self._fail_fast and six.PY2 else exc - - # If we're not failing fast and all executions are raising, there is a chance of recursing - # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry - # and let the event loop thread return. - if self._exec_depth < self.max_error_recursion: - self._put_result(e, idx, False) - else: - self.session.submit(self._put_result, e, idx, False) - self._exec_depth -= 1 + self._put_result(e, idx, False) def _on_success(self, result, future, idx): future.clear_callbacks() diff --git a/cassandra/connection.py b/cassandra/connection.py index 11da8a4afe..f43edc4b5d 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -149,8 +149,8 @@ class ProtocolVersionUnsupported(ConnectionException): Server rejected startup message due to unsupported protocol version """ def __init__(self, host, startup_version): - msg = "Unsupported protocol version on %s: %d" % (host, startup_version) - super(ProtocolVersionUnsupported, self).__init__(msg, host) + super(ProtocolVersionUnsupported, self).__init__("Unsupported protocol version on %s: %d", + (host, startup_version)) self.startup_version = startup_version @@ -345,7 +345,6 @@ def _connect_socket(self): self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) self._socket.settimeout(self.connect_timeout) self._socket.connect(sockaddr) - self._socket.settimeout(None) if self._check_hostname: ssl.match_hostname(self._socket.getpeercert(), self.host) sockerr = None @@ -405,7 +404,7 @@ def try_callback(cb): id(self), self.host, exc_info=True) # run first callback from this thread to ensure pool state before leaving - cb, _, _ = requests.popitem()[1] + cb, _ = requests.popitem()[1] try_callback(cb) if not requests: @@ -415,7 +414,7 @@ def try_callback(cb): # The default callback and retry logic is fairly expensive -- we don't # want to tie up the event thread when there are many requests def err_all_callbacks(): - for cb, _, _ in requests.values(): + for cb, _ in requests.values(): try_callback(cb) if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() @@ -446,7 +445,7 @@ def handle_pushed(self, response): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message): if self.is_defunct: raise ConnectionShutdown("Connection to %s is defunct" % self.host) 
elif self.is_closed: @@ -454,7 +453,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages - self._requests[request_id] = (cb, decoder, result_metadata) + self._requests[request_id] = (cb, decoder) self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) return request_id @@ -579,9 +578,8 @@ def process_msg(self, header, body): if stream_id < 0: callback = None decoder = ProtocolHandler.decode_message - result_metadata = None else: - callback, decoder, result_metadata = self._requests.pop(stream_id) + callback, decoder = self._requests.pop(stream_id, None) with self.lock: self.request_ids.append(stream_id) @@ -589,7 +587,7 @@ def process_msg(self, header, body): try: response = decoder(header.version, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor, result_metadata) + header.flags, header.opcode, body, self.decompressor) except Exception as exc: log.exception("Error decoding response from Cassandra. " "%s; buffer: %r", header, self._iobuf.getvalue()) diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index 0bb52d6bff..14b70915a7 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -168,36 +168,6 @@ def __init__(self, self.position = Column.instance_counter Column.instance_counter += 1 - def __ne__(self, other): - if isinstance(other, Column): - return self.position != other.position - return NotImplemented - - def __eq__(self, other): - if isinstance(other, Column): - return self.position == other.position - return NotImplemented - - def __lt__(self, other): - if isinstance(other, Column): - return self.position < other.position - return NotImplemented - - def __le__(self, other): - if isinstance(other, Column): - return self.position <= other.position - return NotImplemented - - def __gt__(self, other): - if isinstance(other, Column): - return self.position > other.position - return NotImplemented - - def __ge__(self, other): - if isinstance(other, Column): - return self.position >= other.position - return NotImplemented - def validate(self, value): """ Returns a cleaned and validated value. Raises a ValidationError @@ -309,6 +279,13 @@ def to_database(self, value): Bytes = Blob +class Ascii(Column): + """ + Stores a US-ASCII character string + """ + db_type = 'ascii' + + class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format @@ -328,68 +305,25 @@ def __init__(self, min_length=None, max_length=None, **kwargs): Defaults to 1 if this is a ``required`` column. Otherwise, None. :param int max_length: Sets the maximum length of this string, for validation purposes. 
""" - self.min_length = ( - 1 if not min_length and kwargs.get('required', False) - else min_length) + self.min_length = min_length or (1 if kwargs.get('required', False) else None) self.max_length = max_length - - if self.min_length is not None: - if self.min_length < 0: - raise ValueError( - 'Minimum length is not allowed to be negative.') - - if self.max_length is not None: - if self.max_length < 0: - raise ValueError( - 'Maximum length is not allowed to be negative.') - - if self.min_length is not None and self.max_length is not None: - if self.max_length < self.min_length: - raise ValueError( - 'Maximum length must be greater or equal ' - 'to minimum length.') - super(Text, self).__init__(**kwargs) def validate(self, value): value = super(Text, self).validate(value) + if value is None: + return if not isinstance(value, (six.string_types, bytearray)) and value is not None: raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) - if self.max_length is not None: - if value and len(value) > self.max_length: + if self.max_length: + if len(value) > self.max_length: raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) if self.min_length: - if (self.min_length and not value) or len(value) < self.min_length: + if len(value) < self.min_length: raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value -class Ascii(Text): - """ - Stores a US-ASCII character string - """ - db_type = 'ascii' - - def validate(self, value): - """ Only allow ASCII and None values. - - Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. - the Basic Latin block of the Unicode character set. - - Source: https://github.com/apache/cassandra/blob - /3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra - /serializers/AsciiSerializer.java#L29 - """ - value = super(Ascii, self).validate(value) - if value: - charset = value if isinstance( - value, (bytearray, )) else map(ord, value) - if not set(range(128)).issuperset(charset): - raise ValidationError( - '{!r} is not an ASCII string.'.format(value)) - return value - - class Integer(Column): """ Stores a 32-bit signed integer value diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index cc2a34599f..6978964ad0 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -21,7 +21,7 @@ from cassandra import metadata from cassandra.cqlengine import CQLEngineException -from cassandra.cqlengine import columns, query +from cassandra.cqlengine import columns from cassandra.cqlengine.connection import execute, get_cluster from cassandra.cqlengine.models import Model from cassandra.cqlengine.named import NamedTable @@ -119,12 +119,10 @@ def _get_index_name_by_column(table, column_name): return index_metadata.name -def sync_table(model, keyspaces=None): +def sync_table(model): """ Inspects the model and creates / updates the corresponding table and columns. - If `keyspaces` is specified, the table will be synched for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. - Any User Defined Types used in the table are implicitly synchronized. This function can only add fields that are not part of the primary key. 
@@ -137,20 +135,6 @@ def sync_table(model, keyspaces=None): *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ - - if keyspaces: - if not isinstance(keyspaces, (list, tuple)): - raise ValueError('keyspaces must be a list or a tuple.') - - for keyspace in keyspaces: - with query.ContextQuery(model, keyspace=keyspace) as m: - _sync_table(m) - else: - _sync_table(model) - - -def _sync_table(model): - if not _allow_schema_modification(): return @@ -447,29 +431,15 @@ def _update_options(model): return False -def drop_table(model, keyspaces=None): +def drop_table(model): """ Drops the table indicated by the model, if it exists. - If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. - **This function should be used with caution, especially in production environments. Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).** *There are plans to guard schema-modifying functions with an environment-driven conditional.* """ - - if keyspaces: - if not isinstance(keyspaces, (list, tuple)): - raise ValueError('keyspaces must be a list or a tuple.') - - for keyspace in keyspaces: - with query.ContextQuery(model, keyspace=keyspace) as m: - _drop_table(m) - else: - _drop_table(model) - -def _drop_table(model): if not _allow_schema_modification(): return diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index 41dfc77770..e940955ed4 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -352,7 +352,7 @@ class MultipleObjectsReturned(_MultipleObjectsReturned): _table_name = None # used internally to cache a derived table name def __init__(self, **values): - self._ttl = None + self._ttl = self.__default_ttl__ self._timestamp = None self._conditional = None self._batch = None @@ -361,11 +361,7 @@ def __init__(self, **values): self._values = {} for name, column in self._columns.items(): - # Set default values on instantiation. Thanks to this, we don't have - # to wait anylonger for a call to validate() to have CQLengine set - # default columns values. 
- column_default = column.get_default() if column.has_default else None - value = values.get(name, column_default) + value = values.get(name) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) @@ -695,6 +691,7 @@ def save(self): self._set_persisted() + self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -741,6 +738,7 @@ def update(self, **values): self._set_persisted() + self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -796,10 +794,17 @@ def __new__(cls, name, bases, attrs): # short circuit __discriminator_value__ inheritance attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') - # TODO __default__ttl__ should be removed in the next major release options = attrs.get('__options__') or {} attrs['__default_ttl__'] = options.get('default_time_to_live') + def _transform_column(col_name, col_obj): + column_dict[col_name] = col_obj + if col_obj.primary_key: + primary_keys[col_name] = col_obj + col_obj.set_column_name(col_name) + # set properties + attrs[col_name] = ColumnDescriptor(col_obj) + column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) @@ -844,14 +849,6 @@ def _get_polymorphic_base(bases): has_partition_keys = any(v.partition_key for (k, v) in column_definitions) - def _transform_column(col_name, col_obj): - column_dict[col_name] = col_obj - if col_obj.primary_key: - primary_keys[col_name] = col_obj - col_obj.set_column_name(col_name) - # set properties - attrs[col_name] = ColumnDescriptor(col_obj) - partition_key_index = 0 # transform column definitions for k, v in column_definitions: @@ -871,12 +868,6 @@ def _transform_column(col_name, col_obj): if v.partition_key: v._partition_key_index = partition_key_index partition_key_index += 1 - - overriding = column_dict.get(k) - if overriding: - v.position = overriding.position - v.partition_key = overriding.partition_key - v._partition_key_index = overriding._partition_key_index _transform_column(k, v) partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index e996baea3e..10d27ab580 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -136,8 +136,6 @@ class BatchQuery(object): Handles the batching of queries http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH - - See :doc:`/cqlengine/batches` for more details. """ warn_multiple_exec = True @@ -261,46 +259,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.execute() -class ContextQuery(object): - """ - A Context manager to allow a Model to switch context easily. Presently, the context only - specifies a keyspace for model IO. - - For example: - - .. code-block:: python - - with ContextQuery(Automobile, keyspace='test2') as A: - A.objects.create(manufacturer='honda', year=2008, model='civic') - print len(A.objects.all()) # 1 result - - with ContextQuery(Automobile, keyspace='test4') as A: - print len(A.objects.all()) # 0 result - - """ - - def __init__(self, model, keyspace=None): - """ - :param model: A model. This should be a class type, not an instance. 
- :param keyspace: (optional) A keyspace name - """ - from cassandra.cqlengine import models - - if not issubclass(model, models.Model): - raise CQLEngineException("Models must be derived from base Model.") - - ks = keyspace if keyspace else model.__keyspace__ - new_type = type(model.__name__, (model,), {'__keyspace__': ks}) - - self.model = new_type - - def __enter__(self): - return self.model - - def __exit__(self, exc_type, exc_val, exc_tb): - return - - class AbstractQuerySet(object): def __init__(self, model): @@ -341,7 +299,7 @@ def __init__(self, model): self._count = None self._batch = None - self._ttl = None + self._ttl = getattr(model, '__default_ttl__', None) self._consistency = None self._timestamp = None self._if_not_exists = False @@ -374,7 +332,7 @@ def __call__(self, *args, **kwargs): def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution + if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']: # don't clone these clone.__dict__[k] = None elif k == '_batch': # we need to keep the same batch instance across @@ -587,7 +545,7 @@ def _parse_filter_arg(self, arg): if len(statement) == 1: return arg, None elif len(statement) == 2: - return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) + return statement[0], statement[1] else: raise QueryException("Can't parse '{0}'".format(arg)) @@ -996,8 +954,7 @@ class ModelQuerySet(AbstractQuerySet): def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ # check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field - equal_ops = [self.model._get_column_by_db_name(w.field) \ - for w in self._where if isinstance(w.operator, EqualsOperator) and not isinstance(w.value, Token)] + equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering: raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) ' @@ -1014,9 +971,6 @@ def _select_fields(self): fields = self.model._columns.keys() if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] - # select the partition keys if all model fields are set defer - if not fields: - fields = self.model._partition_keys if self._only_fields: fields = [f for f in fields if f in self._only_fields] if not fields: @@ -1200,7 +1154,6 @@ class Row(Model): return nulled_columns = set() - updated_columns = set() us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, val in values.items(): @@ -1221,16 +1174,13 @@ class Row(Model): continue us.add_update(col, val, operation=col_op) - updated_columns.add(col_name) if us.assignments: self._execute(us) if nulled_columns: - delete_conditional = [condition for condition in self._conditional - if condition.field not in updated_columns] if self._conditional else None ds = DeleteStatement(self.column_family_name, fields=nulled_columns, - where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) + 
where=self._where, conditionals=self._conditional, if_exists=self._if_exists) self._execute(ds) @@ -1277,11 +1227,11 @@ def batch(self, batch_obj): self._batch = batch_obj return self - def _delete_null_columns(self, conditionals=None): + def _delete_null_columns(self): """ executes a delete query to remove columns that have changed to null """ - ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) + ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists) deleted_fields = False for _, v in self.instance._values.items(): col = v.column @@ -1315,8 +1265,6 @@ def update(self): conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) - - updated_columns = set() # get defined fields and their column names for name, col in self.model._columns.items(): # if clustering key is null, don't include non static columns @@ -1334,7 +1282,6 @@ def update(self): static_changed_only = static_changed_only and col.static statement.add_update(col, val, previous=val_mgr.previous_value) - updated_columns.add(col.db_field_name) if statement.assignments: for name, col in self.model._primary_keys.items(): @@ -1345,10 +1292,7 @@ def update(self): self._execute(statement) if not null_clustering_key: - # remove conditions on fields that have been updated - delete_conditionals = [condition for condition in self._conditional - if condition.field not in updated_columns] if self._conditional else None - self._delete_null_columns(delete_conditionals) + self._delete_null_columns() def save(self): """ @@ -1397,7 +1341,7 @@ def delete(self): ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.model._primary_keys.items(): val = getattr(self.instance, name) - if val is None and not col.partition_key: + if val is None and not col.parition_key: continue ds.add_where(col, EqualsOperator(), val) self._execute(ds) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 44ae165e8b..3867704a77 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -35,7 +35,9 @@ def __init__(self, value): def __unicode__(self): from cassandra.encoder import cql_quote - if isinstance(self.value, (list, tuple)): + if isinstance(self.value, bool): + return 'true' if self.value else 'false' + elif isinstance(self.value, (list, tuple)): return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' elif isinstance(self.value, dict): return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index b6a720e6c9..7eb0a2df58 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -107,7 +107,7 @@ def __new__(metacls, name, bases, dct): cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls - if not cls.typename.startswith(apache_cassandra_type_prefix): + if not cls.typename.startswith("'org"): _cqltypes[cls.typename] = cls return cls @@ -682,8 +682,6 @@ class VarcharType(UTF8Type): class _ParameterizedType(_CassandraType): - num_subtypes = 'UNKNOWN' - @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: @@ -804,6 +802,7 @@ def serialize_safe(cls, themap, 
protocol_version): class TupleType(_ParameterizedType): typename = 'tuple' + num_subtypes = 'UNKNOWN' @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -854,7 +853,7 @@ def cql_parameterized_type(cls): class UserType(TupleType): - typename = "org.apache.cassandra.db.marshal.UserType" + typename = "'org.apache.cassandra.db.marshal.UserType'" _cache = {} _module = sys.modules[__name__] @@ -957,7 +956,8 @@ def _make_udt_tuple_type(cls, name, field_names): class CompositeType(_ParameterizedType): - typename = "org.apache.cassandra.db.marshal.CompositeType" + typename = "'org.apache.cassandra.db.marshal.CompositeType'" + num_subtypes = 'UNKNOWN' @classmethod def cql_parameterized_type(cls): @@ -985,13 +985,8 @@ def deserialize_safe(cls, byts, protocol_version): return tuple(result) -class DynamicCompositeType(_ParameterizedType): - typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" - - @classmethod - def cql_parameterized_type(cls): - sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) - return "'%s(%s)'" % (cls.typename, sublist) +class DynamicCompositeType(CompositeType): + typename = "'org.apache.cassandra.db.marshal.DynamicCompositeType'" class ColumnToCollectionType(_ParameterizedType): @@ -1000,11 +995,12 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. """ - typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" + typename = "'org.apache.cassandra.db.marshal.ColumnToCollectionType'" + num_subtypes = 'UNKNOWN' class ReversedType(_ParameterizedType): - typename = "org.apache.cassandra.db.marshal.ReversedType" + typename = "'org.apache.cassandra.db.marshal.ReversedType'" num_subtypes = 1 @classmethod diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 98d562d1bc..6d8b6ce8a2 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -40,7 +40,8 @@ def cql_quote(term): # The ordering of this method is important for the result of this method to # be a native str type (for both Python 2 and 3) - if isinstance(term, str): + # Handle quoting of native str and bool types + if isinstance(term, (str, bool)): return "'%s'" % str(term).replace("'", "''") # This branch of the if statement will only be used by Python 2 to catch # unicode strings, text_type is used to prevent type errors with Python 3. diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index cf1616d45b..dfaea8bfb4 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -16,10 +16,13 @@ # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL import eventlet from eventlet.green import socket +import ssl from eventlet.queue import Queue import logging +import os from threading import Event import time @@ -31,6 +34,15 @@ log = logging.getLogger(__name__) +def is_timeout(err): + return ( + err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or + (err == EINVAL and os.name in ('nt', 'ce')) or + (isinstance(err, ssl.SSLError) and err.args[0] == 'timed out') or + isinstance(err, socket.timeout) + ) + + class EventletConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``eventlet``. 
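Editor's note (annotation, not part of the original patch): the ``is_timeout()`` helper added above, and the ``handle_read()`` hunk that follows, apply a standard non-blocking-read pattern: errnos such as ``EWOULDBLOCK`` mean "no data yet, poll again", while anything else should defunct the connection. A minimal standalone sketch of that pattern; ``read_some`` and ``sock`` are illustrative names, not driver API:

.. code-block:: python

    from errno import EAGAIN, EALREADY, EINPROGRESS, EWOULDBLOCK

    _RETRYABLE = (EINPROGRESS, EALREADY, EWOULDBLOCK, EAGAIN)

    def read_some(sock, bufsize=4096):
        # Mirrors the handle_read() change below: retry on timeout-style
        # errors, surface everything else to the caller.
        while True:
            try:
                return sock.recv(bufsize)
            except OSError as err:
                if err.args and err.args[0] in _RETRYABLE:
                    continue  # transient: the socket simply isn't ready yet
                raise  # genuine failure: caller should mark the connection defunct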
@@ -133,6 +145,8 @@ def handle_read(self): buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: + if is_timeout(err): + continue log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index bf0a4cc181..65572a664c 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -18,16 +18,26 @@ import gevent.ssl import logging +import os import time from six.moves import range +from errno import EINVAL + from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) +def is_timeout(err): + return ( + (err == EINVAL and os.name in ('nt', 'ce')) or + isinstance(err, socket.timeout) + ) + + class GeventConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``gevent``. @@ -121,9 +131,11 @@ def handle_read(self): buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: - log.debug("Exception in read for %s: %s", self, err) - self.defunct(err) - return # leave the read loop + if not is_timeout(err): + log.debug("Exception in read for %s: %s", self, err) + self.defunct(err) + return # leave the read loop + continue if self._iobuf.tell(): self.process_io_buffer() diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index 39f871a135..a3e96a9a03 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -102,10 +102,10 @@ def maybe_start(self): def _run_loop(self): while True: - self._loop.start() + end_condition = self._loop.start() # there are still active watchers, no deadlock with self._lock: - if not self._shutdown and self._live_conns: + if not self._shutdown and (end_condition or self._live_conns): log.debug("Restarting event loop") continue else: @@ -121,7 +121,10 @@ def _cleanup(self): for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() - map(lambda w: w.stop(), (w for w in (conn._write_watcher, conn._read_watcher) if w)) + if conn._write_watcher: + conn._write_watcher.stop() + if conn._read_watcher: + conn._read_watcher.stop() self.notify() # wake the timer watcher log.debug("Waiting for event loop thread to join...") @@ -132,6 +135,7 @@ def _cleanup(self): "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") + self._loop = None def add_timer(self, timer): self._timers.add_timer(timer) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index dedaa2de7b..1cd801eed2 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1058,11 +1058,14 @@ def is_cql_compatible(self): """ comparator = getattr(self, 'comparator', None) if comparator: + # no such thing as DCT in CQL + incompatible = issubclass(self.comparator, types.DynamicCompositeType) + # no compact storage with more than one column beyond PK if there # are clustering columns - incompatible = (self.is_compact_storage and - len(self.columns) > len(self.primary_key) + 1 and - len(self.clustering_key) >= 1) + incompatible |= (self.is_compact_storage and + len(self.columns) > len(self.primary_key) + 1 and + len(self.clustering_key) >= 1) return not incompatible return True @@ -1774,9 +1777,12 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator - is_dct_comparator = issubclass(comparator, types.DynamicCompositeType) - 
is_composite_comparator = issubclass(comparator, types.CompositeType)
-        column_name_types = comparator.subtypes if is_composite_comparator else (comparator,)
+        if issubclass(comparator, types.CompositeType):
+            column_name_types = comparator.subtypes
+            is_composite_comparator = True
+        else:
+            column_name_types = (comparator,)
+            is_composite_comparator = False
 
         num_column_name_components = len(column_name_types)
         last_col = column_name_types[-1]
@@ -1790,8 +1796,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
 
         if column_aliases is not None:
             column_aliases = json.loads(column_aliases)
-
-        if not column_aliases:  # json load failed or column_aliases empty PYTHON-562
+        else:
             column_aliases = [r.get('column_name') for r in clustering_rows]
 
         if is_composite_comparator:
@@ -1814,10 +1819,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
 
                 # Some thrift tables define names in composite types (see PYTHON-192)
                 if not column_aliases and hasattr(comparator, 'fieldnames'):
-                    column_aliases = filter(None, comparator.fieldnames)
+                    column_aliases = comparator.fieldnames
             else:
                 is_compact = True
-                if column_aliases or not col_rows or is_dct_comparator:
+                if column_aliases or not col_rows:
                     has_value = True
                     clustering_size = num_column_name_components
                 else:
@@ -1862,7 +1867,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
                 if len(column_aliases) > i:
                     column_name = column_aliases[i]
                 else:
-                    column_name = "column%d" % (i + 1)
+                    column_name = "column%d" % i
 
                 data_type = column_name_types[i]
                 cql_type = _cql_from_cass_type(data_type)
diff --git a/cassandra/metrics.py b/cassandra/metrics.py
index d0c5b9e39c..cf1f25c15d 100644
--- a/cassandra/metrics.py
+++ b/cassandra/metrics.py
@@ -111,14 +111,10 @@ class Metrics(object):
     the driver currently has open.
     """
 
-    _stats_counter = 0
-
     def __init__(self, cluster_proxy):
         log.debug("Starting metric capture")
-        self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter))
-        Metrics._stats_counter += 1
-        self.stats = scales.collection(self.stats_name,
+        self.stats = scales.collection('/cassandra',
             scales.PmfStat('request_timer'),
             scales.IntStat('connection_errors'),
             scales.IntStat('write_timeouts'),
@@ -136,11 +132,6 @@ def __init__(self, cluster_proxy):
             scales.Stat('open_connections',
                         lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions)))
 
-        # TODO, to be removed in 4.0
-        # /cassandra contains the metrics of the first cluster registered
-        if 'cassandra' not in scales._Stats.stats:
-            scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name]
-
         self.request_timer = self.stats.request_timer
         self.connection_errors = self.stats.connection_errors
         self.write_timeouts = self.stats.write_timeouts
@@ -173,27 +164,3 @@ def on_ignore(self):
 
     def on_retry(self):
         self.stats.retries += 1
-
-    def get_stats(self):
-        """
-        Returns the metrics for the registered cluster instance.
-        """
-        return scales.getStats()[self.stats_name]
-
-    def set_stats_name(self, stats_name):
-        """
-        Set the metrics stats name.
-        The stats_name is a string used to access the metrics through scales: scales.getStats()[<stats_name>]
-        Default is 'cassandra-<num_cluster>'.
- """ - - if self.stats_name == stats_name: - return - - if stats_name in scales._Stats.stats: - raise ValueError('"{0}" already exists in stats.'.format(stats_name)) - - stats = scales._Stats.stats[self.stats_name] - del scales._Stats.stats[self.stats_name] - self.stats_name = stats_name - scales._Stats.stats[self.stats_name] = stats diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index ed755d00a4..1334e747c4 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -13,7 +13,7 @@ # limitations under the License. """ -This module provides an optional protocol parser that returns +This module provider an optional protocol parser that returns NumPy arrays. ============================================================================= @@ -25,7 +25,7 @@ as numpy is an optional dependency. include "ioutils.pyx" cimport cython -from libc.stdint cimport uint64_t, uint8_t +from libc.stdint cimport uint64_t from cpython.ref cimport Py_INCREF, PyObject from cassandra.bytesio cimport BytesIOReader @@ -35,6 +35,7 @@ from cassandra import cqltypes from cassandra.util import is_little_endian import numpy as np +# import pandas as pd cdef extern from "numpyFlags.h": # Include 'numpyFlags.h' into the generated C code to disable the @@ -51,13 +52,11 @@ ctypedef struct ArrDesc: Py_uintptr_t buf_ptr int stride # should be large enough as we allocate contiguous arrays int is_object - Py_uintptr_t mask_ptr arrDescDtype = np.dtype( [ ('buf_ptr', np.uintp) , ('stride', np.dtype('i')) , ('is_object', np.dtype('i')) - , ('mask_ptr', np.uintp) ], align=True) _cqltype_to_numpy = { @@ -71,7 +70,6 @@ _cqltype_to_numpy = { obj_dtype = np.dtype('O') -cdef uint8_t mask_true = 0x01 cdef class NumpyParser(ColumnParser): """Decode a ResultMessage into a bunch of NumPy arrays""" @@ -118,11 +116,7 @@ def make_arrays(ParseDesc desc, array_size): arr = make_array(coltype, array_size) array_descs[i]['buf_ptr'] = arr.ctypes.data array_descs[i]['stride'] = arr.strides[0] - array_descs[i]['is_object'] = arr.dtype is obj_dtype - try: - array_descs[i]['mask_ptr'] = arr.mask.ctypes.data - except AttributeError: - array_descs[i]['mask_ptr'] = 0 + array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy arrays.append(arr) return array_descs, arrays @@ -132,12 +126,8 @@ def make_array(coltype, array_size): """ Allocate a new NumPy array of the given column type and size. 
""" - try: - a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype]) - a.mask = np.zeros((array_size,), dtype=np.bool) - except KeyError: - a = np.empty((array_size,), dtype=obj_dtype) - return a + dtype = _cqltype_to_numpy.get(coltype, obj_dtype) + return np.empty((array_size,), dtype=dtype) #### Parse rows into NumPy arrays @@ -150,6 +140,7 @@ cdef inline int unpack_row( cdef Py_ssize_t i, rowsize = desc.rowsize cdef ArrDesc arr cdef Deserializer deserializer + for i in range(rowsize): get_buf(reader, &buf) arr = arrays[i] @@ -159,14 +150,13 @@ cdef inline int unpack_row( val = from_binary(deserializer, &buf, desc.protocol_version) Py_INCREF(val) ( arr.buf_ptr)[0] = val - elif buf.size >= 0: - memcpy( arr.buf_ptr, buf.ptr, buf.size) + elif buf.size < 0: + raise ValueError("Cannot handle NULL value") else: - memcpy(arr.mask_ptr, &mask_true, 1) + memcpy( arr.buf_ptr, buf.ptr, buf.size) # Update the pointer into the array for the next time arrays[i].buf_ptr += arr.stride - arrays[i].mask_ptr += 1 return 0 diff --git a/cassandra/protocol.py b/cassandra/protocol.py index e9e4450f5a..4c63d557d5 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -126,7 +126,7 @@ def __init__(self, code, message, info): self.info = info @classmethod - def recv_body(cls, f, *args): + def recv_body(cls, f, protocol_version, user_type_map): code = read_int(f) msg = read_string(f) subcls = error_classes.get(code, cls) @@ -378,7 +378,7 @@ class ReadyMessage(_MessageType): name = 'READY' @classmethod - def recv_body(cls, *args): + def recv_body(cls, f, protocol_version, user_type_map): return cls() @@ -390,7 +390,7 @@ def __init__(self, authenticator): self.authenticator = authenticator @classmethod - def recv_body(cls, f, *args): + def recv_body(cls, f, protocol_version, user_type_map): authname = read_string(f) return cls(authenticator=authname) @@ -422,7 +422,7 @@ def __init__(self, challenge): self.challenge = challenge @classmethod - def recv_body(cls, f, *args): + def recv_body(cls, f, protocol_version, user_type_map): return cls(read_binary_longstring(f)) @@ -445,7 +445,7 @@ def __init__(self, token): self.token = token @classmethod - def recv_body(cls, f, *args): + def recv_body(cls, f, protocol_version, user_type_map): return cls(read_longstring(f)) @@ -466,7 +466,7 @@ def __init__(self, cql_versions, options): self.options = options @classmethod - def recv_body(cls, f, *args): + def recv_body(cls, f, protocol_version, user_type_map): options = read_stringmultimap(f) cql_versions = options.pop('CQL_VERSION') return cls(cql_versions=cql_versions, options=options) @@ -474,7 +474,7 @@ def recv_body(cls, f, *args): # used for QueryMessage and ExecuteMessage _VALUES_FLAG = 0x01 -_SKIP_METADATA_FLAG = 0x02 +_SKIP_METADATA_FLAG = 0x01 _PAGE_SIZE_FLAG = 0x04 _WITH_PAGING_STATE_FLAG = 0x08 _WITH_SERIAL_CONSISTENCY_FLAG = 0x10 @@ -577,14 +577,14 @@ def __init__(self, kind, results, paging_state=None): self.paging_state = paging_state @classmethod - def recv_body(cls, f, protocol_version, user_type_map, result_metadata): + def recv_body(cls, f, protocol_version, user_type_map): kind = read_int(f) paging_state = None if kind == RESULT_KIND_VOID: results = None elif kind == RESULT_KIND_ROWS: paging_state, results = cls.recv_results_rows( - f, protocol_version, user_type_map, result_metadata) + f, protocol_version, user_type_map) elif kind == RESULT_KIND_SET_KEYSPACE: ksname = read_string(f) results = ksname @@ -597,9 +597,8 @@ def recv_body(cls, f, protocol_version, user_type_map, result_metadata): 
return cls(kind, results, paging_state) @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): + def recv_results_rows(cls, f, protocol_version, user_type_map): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) - column_metadata = column_metadata or result_metadata rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] colnames = [c[2] for c in column_metadata] @@ -608,29 +607,24 @@ def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): tuple(ctype.from_binary(val, protocol_version) for ctype, val in zip(coltypes, row)) for row in rows] - return paging_state, (colnames, parsed_rows) + return (paging_state, (colnames, parsed_rows)) @classmethod def recv_results_prepared(cls, f, protocol_version, user_type_map): query_id = read_binary_string(f) - bind_metadata, pk_indexes, result_metadata = cls.recv_prepared_metadata(f, protocol_version, user_type_map) - return query_id, bind_metadata, pk_indexes, result_metadata + column_metadata, pk_indexes = cls.recv_prepared_metadata(f, protocol_version, user_type_map) + return (query_id, column_metadata, pk_indexes) @classmethod def recv_results_metadata(cls, f, user_type_map): flags = read_int(f) + glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) if flags & cls._HAS_MORE_PAGES_FLAG: paging_state = read_binary_longstring(f) else: paging_state = None - - no_meta = bool(flags & cls._NO_METADATA_FLAG) - if no_meta: - return paging_state, [] - - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) @@ -650,17 +644,17 @@ def recv_results_metadata(cls, f, user_type_map): @classmethod def recv_prepared_metadata(cls, f, protocol_version, user_type_map): flags = read_int(f) + glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) pk_indexes = None if protocol_version >= 4: num_pk_indexes = read_int(f) pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) - bind_metadata = [] + column_metadata = [] for _ in range(colcount): if glob_tblspec: colksname = ksname @@ -670,13 +664,8 @@ def recv_prepared_metadata(cls, f, protocol_version, user_type_map): colcfname = read_string(f) colname = read_string(f) coltype = cls.read_type(f, user_type_map) - bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) - - if protocol_version >= 2: - _, result_metadata = cls.recv_results_metadata(f, user_type_map) - return bind_metadata, pk_indexes, result_metadata - else: - return bind_metadata, pk_indexes, None + column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) + return column_metadata, pk_indexes @classmethod def recv_results_schema_change(cls, f, protocol_version): @@ -738,7 +727,7 @@ class ExecuteMessage(_MessageType): def __init__(self, query_id, query_params, consistency_level, serial_consistency_level=None, fetch_size=None, - paging_state=None, timestamp=None, skip_meta=False): + paging_state=None, timestamp=None): self.query_id = query_id self.query_params = query_params self.consistency_level = consistency_level @@ -746,7 +735,6 @@ def __init__(self, query_id, query_params, consistency_level, self.fetch_size = fetch_size self.paging_state = paging_state self.timestamp = timestamp - self.skip_meta = skip_meta def send_body(self, f, 
protocol_version): write_string(f, self.query_id) @@ -780,8 +768,6 @@ def send_body(self, f, protocol_version): raise UnsupportedOperation( "Protocol-level timestamps may only be used with protocol version " "3 or higher. Consider setting Cluster.protocol_version to 3.") - if self.skip_meta: - flags |= _SKIP_METADATA_FLAG write_byte(f, flags) write_short(f, len(self.query_params)) for param in self.query_params: @@ -796,7 +782,6 @@ def send_body(self, f, protocol_version): write_long(f, self.timestamp) - class BatchMessage(_MessageType): opcode = 0x0D name = 'BATCH' @@ -866,7 +851,7 @@ def __init__(self, event_type, event_args): self.event_args = event_args @classmethod - def recv_body(cls, f, protocol_version, *args): + def recv_body(cls, f, protocol_version, user_type_map): event_type = read_string(f).upper() if event_type in known_event_types: read_method = getattr(cls, 'recv_' + event_type.lower()) @@ -975,7 +960,7 @@ def _write_header(f, version, flags, stream_id, opcode, length): @classmethod def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body, - decompressor, result_metadata): + decompressor): """ Decodes a native protocol message body @@ -1017,7 +1002,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) msg_class = cls.message_types_by_opcode[opcode] - msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata) + msg = msg_class.recv_body(body, protocol_version, user_type_map) msg.stream_id = stream_id msg.trace_id = trace_id msg.custom_payload = custom_payload diff --git a/cassandra/query.py b/cassandra/query.py index 65cb6ba9e0..8662f0bda4 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -219,7 +219,8 @@ class Statement(object): _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, + custom_payload=None): if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') self.retry_policy = retry_policy @@ -361,34 +362,36 @@ class PreparedStatement(object): may affect performance (as the operation requires a network roundtrip). 
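    Illustrative usage (editor's sketch, not part of the patch;
    ``Session.prepare`` returns one of these and ``bind`` is defined below —
    ``user_id`` here is a hypothetical value)::

        prepared = session.prepare("SELECT name FROM users WHERE id = ?")
        bound = prepared.bind((user_id,))
        rows = session.execute(bound)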
""" - column_metadata = None #TODO: make this bind_metadata in next major - consistency_level = None - custom_payload = None - fetch_size = FETCH_SIZE_UNSET - keyspace = None # change to prepared_keyspace in major release - protocol_version = None + column_metadata = None query_id = None query_string = None - result_metadata = None + keyspace = None # change to prepared_keyspace in major release + routing_key_indexes = None _routing_key_index_set = None + + consistency_level = None serial_consistency_level = None + protocol_version = None + + fetch_size = FETCH_SIZE_UNSET + + custom_payload = None + def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version, result_metadata): + keyspace, protocol_version): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version - self.result_metadata = result_metadata @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, - query, prepared_keyspace, protocol_version, result_metadata): + def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata) + return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version) if pk_indexes: routing_key_indexes = pk_indexes @@ -413,7 +416,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version, result_metadata) + query, prepared_keyspace, protocol_version) def bind(self, values): """ diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 8422d544d3..ec2b83bed7 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -18,15 +18,13 @@ from cassandra.deserializers import make_deserializers include "ioutils.pyx" def make_recv_results_rows(ColumnParser colparser): - def recv_results_rows(cls, f, int protocol_version, user_type_map, result_metadata): + def recv_results_rows(cls, f, int protocol_version, user_type_map): """ Parse protocol data given as a BytesIO f into a set of columns (e.g. 
list of tuples) This is used as the recv_results_rows method of (Fast)ResultMessage """ paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) - column_metadata = column_metadata or result_metadata - colnames = [c[2] for c in column_metadata] coltypes = [c[3] for c in column_metadata] diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py index daf882e46c..2f0ce8f5a0 100644 --- a/cassandra/type_codes.py +++ b/cassandra/type_codes.py @@ -59,3 +59,4 @@ SetType = 0x0022 UserType = 0x0030 TupleType = 0x0031 + diff --git a/cassandra/util.py b/cassandra/util.py index 7f17e85d18..f4bc1b1c94 100644 --- a/cassandra/util.py +++ b/cassandra/util.py @@ -973,8 +973,6 @@ def __eq__(self, other): microsecond=self.nanosecond // Time.MICRO) == other def __lt__(self, other): - if not isinstance(other, Time): - return NotImplemented return self.nanosecond_time < other.nanosecond_time def __repr__(self): @@ -1063,8 +1061,6 @@ def __eq__(self, other): return False def __lt__(self, other): - if not isinstance(other, Date): - return NotImplemented return self.days_from_epoch < other.days_from_epoch def __repr__(self): diff --git a/docs.yaml b/docs.yaml index b337d5dd7b..aa30ed5df3 100644 --- a/docs.yaml +++ b/docs.yaml @@ -6,8 +6,3 @@ sections: prefix: / type: sphinx directory: docs -versions: - - name: 3.5.0 - ref: 3.5.0 -redirects: - - \A\/(.*)/\Z: /\1.html diff --git a/docs/api/cassandra/cqlengine/models.rst b/docs/api/cassandra/cqlengine/models.rst index fd081fb190..d6f3391974 100644 --- a/docs/api/cassandra/cqlengine/models.rst +++ b/docs/api/cassandra/cqlengine/models.rst @@ -32,10 +32,8 @@ Model .. autoattribute:: __keyspace__ - .. attribute:: __default_ttl__ - :annotation: = None - - Will be deprecated in release 4.0. You can set the default ttl by configuring the table ``__options__``. See :ref:`ttl-change` for more details. + .. _ttl-change: + .. autoattribute:: __default_ttl__ .. autoattribute:: __discriminator_value__ diff --git a/docs/api/cassandra/cqlengine/query.rst b/docs/api/cassandra/cqlengine/query.rst index c0c8f285cf..461ec9b969 100644 --- a/docs/api/cassandra/cqlengine/query.rst +++ b/docs/api/cassandra/cqlengine/query.rst @@ -54,14 +54,6 @@ The methods here are used to filter, order, and constrain results. .. automethod:: update -.. autoclass:: BatchQuery - :members: - - .. automethod:: add_query - .. automethod:: execute - -.. autoclass:: ContextQuery - .. autoclass:: DoesNotExist .. autoclass:: MultipleObjectsReturned diff --git a/docs/cqlengine/queryset.rst b/docs/cqlengine/queryset.rst index c9c33932f8..ff328b0ce4 100644 --- a/docs/cqlengine/queryset.rst +++ b/docs/cqlengine/queryset.rst @@ -343,42 +343,6 @@ None means no timeout. Setting the timeout on the model is meaningless and will raise an AssertionError. -.. _ttl-change: - -Default TTL and Per Query TTL -============================= - -Model default TTL now relies on the *default_time_to_live* feature, introduced in Cassandra 2.0. It is not handled anymore in the CQLEngine Model (cassandra-driver >=3.6). You can set the default TTL of a table like this: - - Example: - - .. code-block:: python - - class User(Model): - __options__ = {'default_time_to_live': 20} - - user_id = columns.UUID(primary_key=True) - ... - -You can set TTL per-query if needed. Here are a some examples: - - Example: - - .. code-block:: python - - class User(Model): - __options__ = {'default_time_to_live': 20} - - user_id = columns.UUID(primary_key=True) - ... 
- - user = User.objects.create(user_id=1) # Default TTL 20 will be set automatically on the server - - user.ttl(30).update(age=21) # Update the TTL to 30 - User.objects.ttl(10).create(user_id=1) # TTL 10 - User(user_id=1, age=21).ttl(10).save() # TTL 10 - - Named Tables =================== diff --git a/docs/getting_started.rst b/docs/getting_started.rst index c7cbc25970..2d9c7ea461 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -179,7 +179,7 @@ Named place-holders use the ``%(name)s`` form: """ INSERT INTO users (name, credits, user_id, username) VALUES (%(name)s, %(credits)s, %(user_id)s, %(name)s) - """, + """ {'name': "John O'Reilly", 'credits': 42, 'user_id': uuid.uuid1()} ) diff --git a/test-requirements.txt b/test-requirements.txt index 500795357c..4c917da6c6 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,7 +1,7 @@ -r requirements.txt scales nose -mock!=1.1.* +mock<=1.0.1 ccm>=2.0 unittest2 PyYAML diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index bd9fe103cd..62a58896a4 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -23,7 +23,6 @@ import sys import time import traceback -import platform from threading import Event from subprocess import call from itertools import groupby @@ -138,67 +137,14 @@ def _get_cass_version_from_dse(dse_version): CCM_KWARGS['dse_credentials_file'] = DSE_CRED -def get_default_protocol(): - - if CASSANDRA_VERSION >= '2.2': - return 4 - elif CASSANDRA_VERSION >= '2.1': - return 3 - elif CASSANDRA_VERSION >= '2.0': - return 2 - else: - return 1 - - -def get_supported_protocol_versions(): - """ - 1.2 -> 1 - 2.0 -> 2, 1 - 2.1 -> 3, 2, 1 - 2.2 -> 4, 3, 2, 1 - 3.X -> 4, 3 -` """ - if CASSANDRA_VERSION >= '3.0': - return (3, 4) - elif CASSANDRA_VERSION >= '2.2': - return (1, 2, 3, 4) - elif CASSANDRA_VERSION >= '2.1': - return (1, 2, 3) - elif CASSANDRA_VERSION >= '2.0': - return (1, 2) - else: - return (1) - - -def get_unsupported_lower_protocol(): - """ - This is used to determine the lowest protocol version that is NOT - supported by the version of C* running - """ - - if CASSANDRA_VERSION >= '3.0': - return 2 - else: - return None - - -def get_unsupported_upper_protocol(): - """ - This is used to determine the highest protocol version that is NOT - supported by the version of C* running - """ - - if CASSANDRA_VERSION >= '2.2': - return None - if CASSANDRA_VERSION >= '2.1': - return 4 - elif CASSANDRA_VERSION >= '2.0': - return 3 - else: - return None - -default_protocol_version = get_default_protocol() - +if CASSANDRA_VERSION >= '2.2': + default_protocol_version = 4 +elif CASSANDRA_VERSION >= '2.1': + default_protocol_version = 3 +elif CASSANDRA_VERSION >= '2.0': + default_protocol_version = 2 +else: + default_protocol_version = 1 PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version)) @@ -211,7 +157,6 @@ def get_unsupported_upper_protocol(): greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required') lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less then 3.0 required') dseonly = unittest.skipUnless(DSE_VERSION, "Test is only applicalbe to DSE clusters") -pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy") def wait_for_node_socket(node, timeout): @@ -296,7 +241,6 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=[]): log.debug("Using external CCM 
cluster {0}".format(CCM_CLUSTER.name)) else: log.debug("Using unnamed external cluster") - setup_keyspace(ipformat=ipformat, wait=False) return if is_current_cluster(cluster_name, nodes): @@ -443,10 +387,9 @@ def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): cluster.shutdown() -def setup_keyspace(ipformat=None, wait=True): +def setup_keyspace(ipformat=None): # wait for nodes to startup - if wait: - time.sleep(10) + time.sleep(10) if not ipformat: cluster = Cluster(protocol_version=PROTOCOL_VERSION) @@ -538,8 +481,8 @@ def create_keyspace(cls, rf): execute_with_long_wait_retry(cls.session, ddl) @classmethod - def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, metrics=False): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, metrics_enabled=metrics) + def common_setup(cls, rf, keyspace_creation=True, create_class_table=False): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) cls.session = cls.cluster.connect() cls.ks_name = cls.__name__.lower() if keyspace_creation: @@ -592,7 +535,6 @@ def get_message_count(self, level, sub_string): count+=1 return count - class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): """ This is basic unit test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test. @@ -647,7 +589,7 @@ class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase): """ @classmethod def setUpClass(self): - self.common_setup(3, True, True, True) + self.common_setup(2, True) class BasicSharedKeyspaceUnitTestCaseRF3(BasicSharedKeyspaceUnitTestCase): diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py index 3f163ded64..e61698e82c 100644 --- a/tests/integration/cqlengine/__init__.py +++ b/tests/integration/cqlengine/__init__.py @@ -96,7 +96,7 @@ def wrapped_function(*args, **kwargs): else: test_case = args[0] # Check to see if the count is what you expect - test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter())) + test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls doesn't match actual number invoked Expected: {0}, Invoked {1}".format(count.get_counter(), expected)) return to_return # Name of the wrapped function must match the original or unittest will error out. wrapped_function.__name__ = fn.__name__ diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index 4980415208..0480fe43e8 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -17,14 +17,12 @@ except ImportError: import unittest # noqa -import sys from datetime import datetime, timedelta, date, tzinfo from decimal import Decimal as D from uuid import uuid4, uuid1 from cassandra import InvalidRequest from cassandra.cqlengine.columns import TimeUUID -from cassandra.cqlengine.columns import Ascii from cassandra.cqlengine.columns import Text from cassandra.cqlengine.columns import Integer from cassandra.cqlengine.columns import BigInt @@ -339,249 +337,50 @@ def test_default_zero_fields_validate(self): it.validate() -class TestAscii(BaseCassEngTestCase): +class TestText(BaseCassEngTestCase): def test_min_length(self): - """ Test arbitrary minimal lengths requirements. 
""" - Ascii(min_length=0).validate('') - Ascii(min_length=0).validate(None) - Ascii(min_length=0).validate('kevin') - - Ascii(min_length=1).validate('k') - - Ascii(min_length=5).validate('kevin') - Ascii(min_length=5).validate('kevintastic') - - with self.assertRaises(ValidationError): - Ascii(min_length=1).validate('') - - with self.assertRaises(ValidationError): - Ascii(min_length=1).validate(None) - - with self.assertRaises(ValidationError): - Ascii(min_length=6).validate('') - - with self.assertRaises(ValidationError): - Ascii(min_length=6).validate(None) - - with self.assertRaises(ValidationError): - Ascii(min_length=6).validate('kevin') - - with self.assertRaises(ValueError): - Ascii(min_length=-1) - - def test_max_length(self): - """ Test arbitrary maximal lengths requirements. """ - Ascii(max_length=0).validate('') - Ascii(max_length=0).validate(None) - - Ascii(max_length=1).validate('') - Ascii(max_length=1).validate(None) - Ascii(max_length=1).validate('b') - - Ascii(max_length=5).validate('') - Ascii(max_length=5).validate(None) - Ascii(max_length=5).validate('b') - Ascii(max_length=5).validate('blake') + # not required defaults to 0 + col = Text() + col.validate('') + col.validate('b') + # required defaults to 1 with self.assertRaises(ValidationError): - Ascii(max_length=0).validate('b') - - with self.assertRaises(ValidationError): - Ascii(max_length=5).validate('blaketastic') - - with self.assertRaises(ValueError): - Ascii(max_length=-1) - - def test_length_range(self): - Ascii(min_length=0, max_length=0) - Ascii(min_length=0, max_length=1) - Ascii(min_length=10, max_length=10) - Ascii(min_length=10, max_length=11) - - with self.assertRaises(ValueError): - Ascii(min_length=10, max_length=9) - - with self.assertRaises(ValueError): - Ascii(min_length=1, max_length=0) - - def test_type_checking(self): - Ascii().validate('string') - Ascii().validate(u'unicode') - Ascii().validate(bytearray('bytearray', encoding='ascii')) - - with self.assertRaises(ValidationError): - Ascii().validate(5) - - with self.assertRaises(ValidationError): - Ascii().validate(True) - - Ascii().validate("!#$%&\'()*+,-./") - - with self.assertRaises(ValidationError): - Ascii().validate('Beyonc' + chr(233)) - - if sys.version_info < (3, 1): - with self.assertRaises(ValidationError): - Ascii().validate(u'Beyonc' + unichr(233)) - - def test_unaltering_validation(self): - """ Test the validation step doesn't re-interpret values. """ - self.assertEqual(Ascii().validate(''), '') - self.assertEqual(Ascii().validate(None), None) - self.assertEqual(Ascii().validate('yo'), 'yo') - - def test_non_required_validation(self): - """ Tests that validation is ok on none and blank values if required is False. """ - Ascii().validate('') - Ascii().validate(None) - - def test_required_validation(self): - """ Tests that validation raise on none and blank values if value required. """ - Ascii(required=True).validate('k') - - with self.assertRaises(ValidationError): - Ascii(required=True).validate('') - - with self.assertRaises(ValidationError): - Ascii(required=True).validate(None) - - # With min_length set. - Ascii(required=True, min_length=0).validate('k') - Ascii(required=True, min_length=1).validate('k') - - with self.assertRaises(ValidationError): - Ascii(required=True, min_length=2).validate('k') - - # With max_length set. 
- Ascii(required=True, max_length=1).validate('k') - - with self.assertRaises(ValidationError): - Ascii(required=True, max_length=2).validate('kevin') - - with self.assertRaises(ValueError): - Ascii(required=True, max_length=0) - - -class TestText(BaseCassEngTestCase): + Text(required=True).validate('') - def test_min_length(self): - """ Test arbitrary minimal lengths requirements. """ + #test arbitrary lengths Text(min_length=0).validate('') - Text(min_length=0).validate(None) - Text(min_length=0).validate('blake') - - Text(min_length=1).validate('b') - Text(min_length=5).validate('blake') Text(min_length=5).validate('blaketastic') - - with self.assertRaises(ValidationError): - Text(min_length=1).validate('') - - with self.assertRaises(ValidationError): - Text(min_length=1).validate(None) - - with self.assertRaises(ValidationError): - Text(min_length=6).validate('') - - with self.assertRaises(ValidationError): - Text(min_length=6).validate(None) - with self.assertRaises(ValidationError): Text(min_length=6).validate('blake') - with self.assertRaises(ValueError): - Text(min_length=-1) - def test_max_length(self): - """ Test arbitrary maximal lengths requirements. """ - Text(max_length=0).validate('') - Text(max_length=0).validate(None) - - Text(max_length=1).validate('') - Text(max_length=1).validate(None) - Text(max_length=1).validate('b') - Text(max_length=5).validate('') - Text(max_length=5).validate(None) - Text(max_length=5).validate('b') Text(max_length=5).validate('blake') - - with self.assertRaises(ValidationError): - Text(max_length=0).validate('b') - with self.assertRaises(ValidationError): Text(max_length=5).validate('blaketastic') - with self.assertRaises(ValueError): - Text(max_length=-1) - - def test_length_range(self): - Text(min_length=0, max_length=0) - Text(min_length=0, max_length=1) - Text(min_length=10, max_length=10) - Text(min_length=10, max_length=11) - - with self.assertRaises(ValueError): - Text(min_length=10, max_length=9) - - with self.assertRaises(ValueError): - Text(min_length=1, max_length=0) - def test_type_checking(self): Text().validate('string') Text().validate(u'unicode') Text().validate(bytearray('bytearray', encoding='ascii')) + with self.assertRaises(ValidationError): + Text(required=True).validate(None) + with self.assertRaises(ValidationError): Text().validate(5) with self.assertRaises(ValidationError): Text().validate(True) - Text().validate("!#$%&\'()*+,-./") - Text().validate('Beyonc' + chr(233)) - if sys.version_info < (3, 1): - Text().validate(u'Beyonc' + unichr(233)) - - def test_unaltering_validation(self): - """ Test the validation step doesn't re-interpret values. """ - self.assertEqual(Text().validate(''), '') - self.assertEqual(Text().validate(None), None) - self.assertEqual(Text().validate('yo'), 'yo') - def test_non_required_validation(self): """ Tests that validation is ok on none and blank values if required is False """ Text().validate('') Text().validate(None) - def test_required_validation(self): - """ Tests that validation raise on none and blank values if value required. """ - Text(required=True).validate('b') - - with self.assertRaises(ValidationError): - Text(required=True).validate('') - - with self.assertRaises(ValidationError): - Text(required=True).validate(None) - - # With min_length set. - Text(required=True, min_length=0).validate('b') - Text(required=True, min_length=1).validate('b') - - with self.assertRaises(ValidationError): - Text(required=True, min_length=2).validate('b') - - # With max_length set. 
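# Editor's aside: the surviving TestText above keeps the same rule these
# deleted lines spelled out: an optional Text column accepts the empty string
# (min_length effectively 0), while required=True rejects it. Runnable sketch:
from cassandra.cqlengine import ValidationError
from cassandra.cqlengine.columns import Text

Text().validate('')                      # optional: empty string accepted
try:
    Text(required=True).validate('')     # required: empty string rejected
except ValidationError:
    pass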
- Text(required=True, max_length=1).validate('b') - - with self.assertRaises(ValidationError): - Text(required=True, max_length=2).validate('blake') - - with self.assertRaises(ValueError): - Text(required=True, max_length=0) - class TestExtraFieldsRaiseException(BaseCassEngTestCase): class TestModel(Model): diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py index e447056376..8147e41079 100644 --- a/tests/integration/cqlengine/model/test_class_construction.py +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -47,30 +47,9 @@ class TestModel(Model): inst = TestModel() self.assertHasAttr(inst, 'id') self.assertHasAttr(inst, 'text') - self.assertIsNotNone(inst.id) + self.assertIsNone(inst.id) self.assertIsNone(inst.text) - def test_values_on_instantiation(self): - """ - Tests defaults and user-provided values on instantiation. - """ - - class TestPerson(Model): - first_name = columns.Text(primary_key=True, default='kevin') - last_name = columns.Text(default='deldycke') - - # Check that defaults are available at instantiation. - inst1 = TestPerson() - self.assertHasAttr(inst1, 'first_name') - self.assertHasAttr(inst1, 'last_name') - self.assertEqual(inst1.first_name, 'kevin') - self.assertEqual(inst1.last_name, 'deldycke') - - # Check that values on instantiation overrides defaults. - inst2 = TestPerson(first_name='bob', last_name='joe') - self.assertEqual(inst2.first_name, 'bob') - self.assertEqual(inst2.last_name, 'joe') - def test_db_map(self): """ Tests that the db_map is properly defined diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index b31b8d5aee..e46698ff75 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -22,8 +22,7 @@ from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace from cassandra.cqlengine import models from cassandra.cqlengine.models import Model, ModelDefinitionException -from uuid import uuid1 -from tests.integration import pypy + class TestModel(unittest.TestCase): """ Tests the non-io functionality of models """ @@ -173,37 +172,4 @@ class IllegalFilterColumnModel(Model): my_primary_key = columns.Integer(primary_key=True) filter = columns.Text() -@pypy -class ModelOverWriteTest(unittest.TestCase): - - def test_model_over_write(self): - """ - Test to ensure overwriting of primary keys in model inheritance is allowed - - This is currently only an issue in PyPy. 
When PYTHON-504 is introduced this should - be updated error out and warn the user - - @since 3.6.0 - @jira_ticket PYTHON-576 - @expected_result primary keys can be overwritten via inheritance - - @test_category object_mapper - """ - class TimeModelBase(Model): - uuid = columns.TimeUUID(primary_key=True) - - class DerivedTimeModel(TimeModelBase): - __table_name__ = 'derived_time' - uuid = columns.TimeUUID(primary_key=True, partition_key=True) - value = columns.Text(required=False) - - # In case the table already exists in keyspace - drop_table(DerivedTimeModel) - - sync_table(DerivedTimeModel) - uuid_value = uuid1() - uuid_value2 = uuid1() - DerivedTimeModel.create(uuid=uuid_value, value="first") - DerivedTimeModel.create(uuid=uuid_value2, value="second") - DerivedTimeModel.objects.filter(uuid=uuid_value) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py index c5fd5e37ca..3faf62febc 100644 --- a/tests/integration/cqlengine/model/test_model_io.py +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -38,6 +38,8 @@ from tests.integration.cqlengine import DEFAULT_KEYSPACE + + class TestModel(Model): id = columns.UUID(primary_key=True, default=lambda: uuid4()) @@ -70,7 +72,7 @@ def tearDownClass(cls): def test_model_save_and_load(self): """ - Tests that models can be saved and retrieved, using the create method. + Tests that models can be saved and retrieved """ tm = TestModel.create(count=8, text='123456789') self.assertIsInstance(tm, TestModel) @@ -81,22 +83,6 @@ def test_model_save_and_load(self): for cname in tm._columns.keys(): self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) - def test_model_instantiation_save_and_load(self): - """ - Tests that models can be saved and retrieved, this time using the - natural model instantiation. - """ - tm = TestModel(count=8, text='123456789') - # Tests that values are available on instantiation. - self.assertIsNotNone(tm['id']) - self.assertEqual(tm.count, 8) - self.assertEqual(tm.text, '123456789') - tm.save() - tm2 = TestModel.objects(id=tm.id).first() - - for cname in tm._columns.keys(): - self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) - def test_model_read_as_dict(self): """ Tests that columns of an instance can be read as a dict. @@ -482,49 +468,6 @@ def test_previous_value_tracking_on_instantiation(self): self.assertTrue(self.instance._values['count'].previous_value is None) self.assertTrue(self.instance.count is None) - def test_previous_value_tracking_on_instantiation_with_default(self): - - class TestDefaultValueTracking(Model): - id = columns.Integer(partition_key=True) - int1 = columns.Integer(default=123) - int2 = columns.Integer(default=456) - int3 = columns.Integer(default=lambda: random.randint(0, 1000)) - int4 = columns.Integer(default=lambda: random.randint(0, 1000)) - int5 = columns.Integer() - int6 = columns.Integer() - - instance = TestDefaultValueTracking( - id=1, - int1=9999, - int3=7777, - int5=5555) - - self.assertEqual(instance.id, 1) - self.assertEqual(instance.int1, 9999) - self.assertEqual(instance.int2, 456) - self.assertEqual(instance.int3, 7777) - self.assertIsNotNone(instance.int4) - self.assertIsInstance(instance.int4, int) - self.assertGreaterEqual(instance.int4, 0) - self.assertLessEqual(instance.int4, 1000) - self.assertEqual(instance.int5, 5555) - self.assertTrue(instance.int6 is None) - - # All previous values are unset as the object hasn't been persisted - # yet. 
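# Editor's aside: cqlengine's dirty tracking hangs off
# instance._values[<column>].previous_value; before the first save it is None
# for every column, which is what the surviving test above asserts and the
# deleted test below extended to defaulted columns. Minimal sketch with a
# hypothetical model:
from cassandra.cqlengine import columns
from cassandra.cqlengine.models import Model

class Probe(Model):                        # hypothetical, for illustration only
    id = columns.Integer(partition_key=True)
    n = columns.Integer()

p = Probe(id=1)
assert p._values['id'].previous_value is None   # never persisted yet
assert p._values['n'].previous_value is None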
- self.assertTrue(instance._values['id'].previous_value is None) - self.assertTrue(instance._values['int1'].previous_value is None) - self.assertTrue(instance._values['int2'].previous_value is None) - self.assertTrue(instance._values['int3'].previous_value is None) - self.assertTrue(instance._values['int4'].previous_value is None) - self.assertTrue(instance._values['int5'].previous_value is None) - self.assertTrue(instance._values['int6'].previous_value is None) - - # All explicitely set columns, and those with default values are - # flagged has changed. - self.assertTrue(set(instance.get_changed_columns()) == set([ - 'id', 'int1', 'int2', 'int3', 'int4', 'int5'])) - def test_save_to_none(self): """ Test update of column value of None with save() function. diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py index bc39d142cf..242bffe12f 100644 --- a/tests/integration/cqlengine/model/test_updates.py +++ b/tests/integration/cqlengine/model/test_updates.py @@ -79,8 +79,8 @@ def test_update_values(self): self.assertEqual(m2.count, m1.count) self.assertEqual(m2.text, m0.text) - def test_noop_model_direct_update(self): - """ Tests that calling update on a model with no changes will do nothing. """ + def test_noop_model_update(self): + """ tests that calling update on a model with no changes will do nothing. """ m0 = TestUpdateModel.create(count=5, text='monkey') with patch.object(self.session, 'execute') as execute: @@ -91,38 +91,6 @@ def test_noop_model_direct_update(self): m0.update(count=5) assert execute.call_count == 0 - with self.assertRaises(ValidationError): - m0.update(partition=m0.partition) - - with self.assertRaises(ValidationError): - m0.update(cluster=m0.cluster) - - def test_noop_model_assignation_update(self): - """ Tests that assigning the same value on a model will do nothing. """ - # Create object and fetch it back to eliminate any hidden variable - # cache effect. - m0 = TestUpdateModel.create(count=5, text='monkey') - m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) - - with patch.object(self.session, 'execute') as execute: - m1.save() - assert execute.call_count == 0 - - with patch.object(self.session, 'execute') as execute: - m1.count = 5 - m1.save() - assert execute.call_count == 0 - - with patch.object(self.session, 'execute') as execute: - m1.partition = m0.partition - m1.save() - assert execute.call_count == 0 - - with patch.object(self.session, 'execute') as execute: - m1.cluster = m0.cluster - m1.save() - assert execute.call_count == 0 - def test_invalid_update_kwarg(self): """ tests that passing in a kwarg to the update method that isn't a column will fail """ m0 = TestUpdateModel.create(count=5, text='monkey') diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 55129cb985..9cddbece17 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -342,7 +342,7 @@ def test_named_table_with_mv(self): # Populate the base table with data prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? 
,?, ?)""".format(ks)) - parameters = (('pcmanus', 'Coup', 2015, 5, 1, 4000), + parameters = {('pcmanus', 'Coup', 2015, 5, 1, 4000), ('jbellis', 'Coup', 2015, 5, 3, 1750), ('yukim', 'Coup', 2015, 5, 3, 2250), ('tjake', 'Coup', 2015, 5, 3, 500), @@ -353,7 +353,7 @@ def test_named_table_with_mv(self): ('jbellis', 'Coup', 2015, 6, 20, 3500), ('jbellis', 'Checkers', 2015, 6, 20, 1200), ('jbellis', 'Chess', 2015, 6, 21, 3500), - ('pcmanus', 'Chess', 2015, 1, 25, 3200)) + ('pcmanus', 'Chess', 2015, 1, 25, 3200)} prepared_insert.consistency_level = ConsistencyLevel.ALL execute_concurrent_with_args(self.session, prepared_insert, parameters) diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index 055e8f3db2..c2a2a74206 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -72,7 +72,7 @@ def tearDown(self): super(TestTokenFunction, self).tearDown() drop_table(TokenTestModel) - @execute_count(15) + @execute_count(14) def test_token_function(self): """ Tests that token functions work properly """ assert TokenTestModel.objects().count() == 0 @@ -91,10 +91,6 @@ def test_token_function(self): assert len(seen_keys) == 10 assert all([i in seen_keys for i in range(10)]) - # pk__token equality - r = TokenTestModel.objects(pk__token=functions.Token(last_token)) - self.assertEqual(len(r), 1) - def test_compound_pk_token_function(self): class TestModel(Model): diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index ea303373b8..0776d67943 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -268,7 +268,6 @@ def test_defining_defer_fields(self): @since 3.5 @jira_ticket PYTHON-560 - @jira_ticket PYTHON-599 @expected_result deferred fields should not be returned @test_category object_mapper @@ -301,10 +300,6 @@ def test_defining_defer_fields(self): q = TestModel.objects.filter(test_id=0) self.assertEqual(q._select_fields(), ['attempt_id', 'description', 'expected_result', 'test_result']) - # when all fields are defered, it fallbacks select the partition keys - q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) - self.assertEqual(q._select_fields(), ['test_id']) - class BaseQuerySetUsage(BaseCassEngTestCase): @@ -852,12 +847,16 @@ def test_tzaware_datetime_support(self): def test_success_case(self): """ Test that the min and max time uuid functions work as expected """ pk = uuid4() - startpoint = datetime.utcnow() - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=1)), data='1') - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=2)), data='2') - midpoint = startpoint + timedelta(seconds=3) - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=4)), data='3') - TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=5)), data='4') + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='1') + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='2') + time.sleep(0.2) + midpoint = datetime.utcnow() + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='3') + time.sleep(0.2) + TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') + time.sleep(0.2) # test 
kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) @@ -895,6 +894,7 @@ def test_success_case(self): class TestInOperator(BaseQuerySetUsage): + @execute_count(1) def test_kwarg_success_case(self): """ Tests the in operator works with the kwarg query method """ @@ -907,51 +907,6 @@ def test_query_expression_success_case(self): q = TestModel.filter(TestModel.test_id.in_([0, 1])) assert q.count() == 8 - @execute_count(5) - def test_bool(self): - """ - Adding coverage to cqlengine for bool types. - - @since 3.6 - @jira_ticket PYTHON-596 - @expected_result bool results should be filtered appropriately - - @test_category object_mapper - """ - class bool_model(Model): - k = columns.Integer(primary_key=True) - b = columns.Boolean(primary_key=True) - v = columns.Integer(default=3) - sync_table(bool_model) - - bool_model.create(k=0, b=True) - bool_model.create(k=0, b=False) - self.assertEqual(len(bool_model.objects.all()), 2) - self.assertEqual(len(bool_model.objects.filter(k=0, b=True)), 1) - self.assertEqual(len(bool_model.objects.filter(k=0, b=False)), 1) - - @execute_count(3) - def test_bool_filter(self): - """ - Test to ensure that we don't translate boolean objects to String unnecessarily in filter clauses - - @since 3.6 - @jira_ticket PYTHON-596 - @expected_result We should not receive a server error - - @test_category object_mapper - """ - class bool_model2(Model): - k = columns.Boolean(primary_key=True) - b = columns.Integer(primary_key=True) - v = columns.Text() - drop_table(bool_model2) - sync_table(bool_model2) - - bool_model2.create(k=True, b=1, v='a') - bool_model2.create(k=False, b=1, v='b') - self.assertEqual(len(list(bool_model2.objects(k__in=(True, False)))), 2) - @greaterthancass20 class TestContainsOperator(BaseQuerySetUsage): @@ -1398,3 +1353,5 @@ def test_defaultFetchSize(self): smiths = list(People2.filter(last_name="Smith")) self.assertEqual(len(smiths), 5) self.assertTrue(smiths[0].last_name is not None) + + diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py deleted file mode 100644 index b3941319e9..0000000000 --- a/tests/integration/cqlengine/test_context_query.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2013-2016 DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
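# Editor's note: the module deleted below exercised ContextQuery (PYTHON-598,
# a 3.6 feature that this revert removes). For reference, the pattern it
# tested — valid only against driver versions that still ship ContextQuery —
# looked like this:
from cassandra.cqlengine.query import ContextQuery

with ContextQuery(TestModel, keyspace='ks2') as tm:
    tm.objects.create(partition=1, cluster=1)   # routed to keyspace ks2
# Outside the block, TestModel keeps its declared __keyspace__ ('ks1').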
- -from cassandra.cqlengine import columns -from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple -from cassandra.cqlengine.models import Model -from cassandra.cqlengine.query import ContextQuery -from tests.integration.cqlengine.base import BaseCassEngTestCase - - -class TestModel(Model): - - __keyspace__ = 'ks1' - - partition = columns.Integer(primary_key=True) - cluster = columns.Integer(primary_key=True) - count = columns.Integer() - text = columns.Text() - - -class ContextQueryTests(BaseCassEngTestCase): - - KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4') - - @classmethod - def setUpClass(cls): - super(ContextQueryTests, cls).setUpClass() - for ks in cls.KEYSPACES: - create_keyspace_simple(ks, 1) - sync_table(TestModel, keyspaces=cls.KEYSPACES) - - @classmethod - def tearDownClass(cls): - super(ContextQueryTests, cls).tearDownClass() - for ks in cls.KEYSPACES: - drop_keyspace(ks) - - def setUp(self): - super(ContextQueryTests, self).setUp() - for ks in self.KEYSPACES: - with ContextQuery(TestModel, keyspace=ks) as tm: - for obj in tm.all(): - obj.delete() - - def test_context_manager(self): - """ - Validates that when a context query is constructed that the - keyspace of the returned model is toggled appropriately - - @since 3.6 - @jira_ticket PYTHON-598 - @expected_result default keyspace should be used - - @test_category query - """ - # model keyspace write/read - for ks in self.KEYSPACES: - with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(tm.__keyspace__, ks) - - self.assertEqual(TestModel._get_keyspace(), 'ks1') - - def test_default_keyspace(self): - """ - Tests the use of context queries with the default model keyspsace - - @since 3.6 - @jira_ticket PYTHON-598 - @expected_result default keyspace should be used - - @test_category query - """ - # model keyspace write/read - for i in range(5): - TestModel.objects.create(partition=i, cluster=i) - - with ContextQuery(TestModel) as tm: - self.assertEqual(5, len(tm.objects.all())) - - with ContextQuery(TestModel, keyspace='ks1') as tm: - self.assertEqual(5, len(tm.objects.all())) - - for ks in self.KEYSPACES[1:]: - with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(0, len(tm.objects.all())) - - def test_context_keyspace(self): - """ - Tests the use of context queries with non default keyspaces - - @since 3.6 - @jira_ticket PYTHON-598 - @expected_result queries should be routed to appropriate keyspaces - - @test_category query - """ - for i in range(5): - with ContextQuery(TestModel, keyspace='ks4') as tm: - tm.objects.create(partition=i, cluster=i) - - with ContextQuery(TestModel, keyspace='ks4') as tm: - self.assertEqual(5, len(tm.objects.all())) - - self.assertEqual(0, len(TestModel.objects.all())) - - for ks in self.KEYSPACES[:2]: - with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(0, len(tm.objects.all())) - - # simple data update - with ContextQuery(TestModel, keyspace='ks4') as tm: - obj = tm.objects.get(partition=1) - obj.update(count=42) - - self.assertEqual(42, tm.objects.get(partition=1).count) - diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py index 8395154c34..d273df9cc0 100644 --- a/tests/integration/cqlengine/test_lwt_conditional.py +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -234,18 +234,3 @@ def test_update_to_none(self): self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) 
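# Editor's aside: iff() turns the UPDATE into a lightweight transaction by
# appending an IF <condition> clause; the write only applies when the stored
# row still matches. Shape of the call, using this module's own model:
t = TestConditionalModel.create(text='something', count=5)
TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None)
# Had count drifted from 5, the condition would fail and cqlengine would
# raise LWTException rather than silently skip the write.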
TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) - - def test_column_delete_after_update(self): - # DML path - t = TestConditionalModel.create(text='something', count=5) - t.iff(count=5).update(text=None, count=6) - - self.assertIsNone(t.text) - self.assertEqual(t.count, 6) - - # QuerySet path - t = TestConditionalModel.create(text='something', count=5) - TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) - - self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) - self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6) diff --git a/tests/integration/cqlengine/test_ttl.py b/tests/integration/cqlengine/test_ttl.py index 3e16292781..ba2c1e0935 100644 --- a/tests/integration/cqlengine/test_ttl.py +++ b/tests/integration/cqlengine/test_ttl.py @@ -18,7 +18,6 @@ except ImportError: import unittest # noqa -from cassandra import InvalidRequest from cassandra.cqlengine.management import sync_table, drop_table from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.models import Model @@ -159,16 +158,6 @@ def test_ttl_included_with_blind_update(self): @unittest.skipIf(CASSANDRA_VERSION < '2.0', "default_time_to_Live was introduce in C* 2.0, currently running {0}".format(CASSANDRA_VERSION)) class TTLDefaultTest(BaseDefaultTTLTest): - def get_default_ttl(self, table_name): - session = get_session() - try: - default_ttl = session.execute("SELECT default_time_to_live FROM system_schema.tables " - "WHERE keyspace_name = 'cqlengine_test' AND table_name = '{0}'".format(table_name)) - except InvalidRequest: - default_ttl = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies " - "WHERE keyspace_name = 'cqlengine_test' AND columnfamily_name = '{0}'".format(table_name)) - return default_ttl[0]['default_time_to_live'] - def test_default_ttl_not_set(self): session = get_session() @@ -177,9 +166,6 @@ def test_default_ttl_not_set(self): self.assertIsNone(o._ttl) - default_ttl = self.get_default_ttl('test_ttlmodel') - self.assertEqual(default_ttl, 0) - with mock.patch.object(session, 'execute') as m: TestTTLModel.objects(id=tid).update(text="aligators") @@ -188,44 +174,23 @@ def test_default_ttl_not_set(self): def test_default_ttl_set(self): session = get_session() - o = TestDefaultTTLModel.create(text="some text on ttl") tid = o.id - # Should not be set, it's handled by Cassandra - self.assertIsNone(o._ttl) - - default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 20) + self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__) with mock.patch.object(session, 'execute') as m: - TestTTLModel.objects(id=tid).update(text="aligators expired") + TestDefaultTTLModel.objects(id=tid).update(text="aligators expired") - # Should not be set either query = m.call_args[0][0].query_string - self.assertNotIn("USING TTL", query) - - def test_default_ttl_modify(self): - session = get_session() - - default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 20) - - TestDefaultTTLModel.__options__ = {'default_time_to_live': 10} - sync_table(TestDefaultTTLModel) - - default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 10) - - # Restore default TTL - TestDefaultTTLModel.__options__ = {'default_time_to_live': 20} - sync_table(TestDefaultTTLModel) + self.assertIn("USING TTL", query) def test_override_default_ttl(self): 
session = get_session() o = TestDefaultTTLModel.create(text="some text on ttl") tid = o.id + self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__) o.ttl(3600) self.assertEqual(o._ttl, 3600) diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index 99875afe48..9ae84f38d6 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -17,7 +17,7 @@ except ImportError: import unittest -import os, sys, traceback, logging, ssl, time +import os, sys, traceback, logging, ssl from cassandra.cluster import Cluster, NoHostAvailable from cassandra import ConsistencyLevel from cassandra.query import SimpleStatement @@ -86,7 +86,7 @@ def validate_ssl_options(ssl_options): raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") try: cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options) - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() break except Exception: ex_type, ex, tb = sys.exc_info() @@ -132,47 +132,11 @@ def test_can_connect_with_ssl_ca(self): @test_category connection:ssl """ - # find absolute path to client CA_CERTS - abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) - ssl_options = {'ca_certs': abs_path_ca_cert_path,'ssl_version': ssl.PROTOCOL_TLSv1} - validate_ssl_options(ssl_options=ssl_options) - - def test_can_connect_with_ssl_long_running(self): - """ - Test to validate that long running ssl connections continue to function past thier timeout window - - @since 3.6.0 - @jira_ticket PYTHON-600 - @expected_result The client can connect via SSL and preform some basic operations over a period of longer then a minute - - @test_category connection:ssl - """ - # find absolute path to client CA_CERTS abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS) ssl_options = {'ca_certs': abs_path_ca_cert_path, 'ssl_version': ssl.PROTOCOL_TLSv1} - tries = 0 - while True: - if tries > 5: - raise RuntimeError("Failed to connect to SSL cluster after 5 attempts") - try: - cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options) - session = cluster.connect(wait_for_all_pools=True) - break - except Exception: - ex_type, ex, tb = sys.exc_info() - log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) - del tb - tries += 1 - - # attempt a few simple commands. 
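# Editor's aside: both the kept test and the long-running test removed here
# drive the same knob — Cluster(ssl_options=...), a dict of keyword arguments
# handed through to the socket-wrapping layer. Minimal connection sketch
# using only names already present in this module:
abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS)
ssl_options = {'ca_certs': abs_path_ca_cert_path,
               'ssl_version': ssl.PROTOCOL_TLSv1}
cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options)
session = cluster.connect()
session.execute("SELECT * FROM system.local")
cluster.shutdown()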
- - for i in range(8): - rs = session.execute("SELECT * FROM system.local") - time.sleep(10) - - cluster.shutdown() + validate_ssl_options(ssl_options=ssl_options) def test_can_connect_with_ssl_ca_host_match(self): """ diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 4c8339a2cf..62244b93f3 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -33,8 +33,7 @@ from cassandra.protocol import MAX_SUPPORTED_VERSION from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory -from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node,\ - MockLoggingHandler, get_unsupported_lower_protocol, get_unsupported_upper_protocol +from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler from tests.integration.util import assert_quiescent_pool_state @@ -42,40 +41,8 @@ def setup_module(): use_singledc() -class IgnoredHostPolicy(RoundRobinPolicy): - - def __init__(self, ignored_hosts): - self.ignored_hosts = ignored_hosts - RoundRobinPolicy.__init__(self) - - def distance(self, host): - if(str(host) in self.ignored_hosts): - return HostDistance.IGNORED - else: - return HostDistance.LOCAL - - class ClusterTests(unittest.TestCase): - def test_ignored_host_up(self): - """ - Test to ensure that is_up is not set by default on ignored hosts - - @since 3.6 - @jira_ticket PYTHON-551 - @expected_result ignored hosts should have None set for is_up - - @test_category connection - """ - ingored_host_policy = IgnoredHostPolicy(["127.0.0.2", "127.0.0.3"]) - cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=ingored_host_policy) - session = cluster.connect() - for host in cluster.metadata.all_hosts(): - if str(host) == "127.0.0.1": - self.assertTrue(host.is_up) - else: - self.assertIsNone(host.is_up) - def test_host_resolution(self): """ Test to insure A records are resolved appropriately. @@ -100,11 +67,11 @@ def test_host_duplication(self): @test_category connection """ cluster = Cluster(contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) - cluster.connect(wait_for_all_pools=True) + cluster.connect() self.assertEqual(len(cluster.metadata.all_hosts()), 3) cluster.shutdown() cluster = Cluster(contact_points=["127.0.0.1", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1) - cluster.connect(wait_for_all_pools=True) + cluster.connect() self.assertEqual(len(cluster.metadata.all_hosts()), 3) cluster.shutdown() @@ -208,42 +175,6 @@ def test_protocol_negotiation(self): cluster.shutdown() - def test_invalid_protocol_negotation(self): - """ - Test for protocol negotiation when explicit versions are set - - If an explicit protocol version that is not compatible with the server version is set - an exception should be thrown. It should not attempt to negotiate - - for reference supported protocol version to server versions is as follows/ - - 1.2 -> 1 - 2.0 -> 2, 1 - 2.1 -> 3, 2, 1 - 2.2 -> 4, 3, 2, 1 - 3.X -> 4, 3 - - @since 3.6.0 - @jira_ticket PYTHON-537 - @expected_result downgrading should not be allowed when explicit protocol versions are set. 
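# Editor's aside: the deleted test asserted that an *explicit*
# protocol_version must fail fast rather than negotiate downward. Sketch of
# that check (assuming a 3.x node, where protocol v2 is unsupported):
from cassandra.cluster import Cluster, NoHostAvailable

cluster = Cluster(protocol_version=2)
try:
    cluster.connect()        # expected: NoHostAvailable, no silent downgrade
except NoHostAvailable:
    pass
finally:
    cluster.shutdown()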
- - @test_category connection - """ - - upper_bound = get_unsupported_upper_protocol() - if upper_bound is not None: - cluster = Cluster(protocol_version=upper_bound) - with self.assertRaises(NoHostAvailable): - cluster.connect() - cluster.shutdown() - - lower_bound = get_unsupported_lower_protocol() - if lower_bound is not None: - cluster = Cluster(protocol_version=lower_bound) - with self.assertRaises(NoHostAvailable): - cluster.connect() - cluster.shutdown() - def test_connect_on_keyspace(self): """ Ensure clusters that connect on a keyspace, do @@ -585,14 +516,14 @@ def test_idle_heartbeat(self): cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval) if PROTOCOL_VERSION < 3: cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() # This test relies on impl details of connection req id management to see if heartbeats # are being sent. May need update if impl is changed connection_request_ids = {} for h in cluster.get_connection_holders(): for c in h.get_connections(): - # make sure none are idle (should have startup messages + # make sure none are idle (should have startup messages) self.assertFalse(c.is_idle) with c.lock: connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids @@ -627,7 +558,7 @@ def test_idle_heartbeat(self): self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc # include additional sessions - session2 = cluster.connect(wait_for_all_pools=True) + session2 = cluster.connect() holders = cluster.get_connection_holders() self.assertIn(cluster.control_connection, holders) @@ -700,7 +631,7 @@ def test_profile_load_balancing(self): query = "select release_version from system.local" node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) with Cluster(execution_profiles={'node1': node1}) as cluster: - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) @@ -757,7 +688,7 @@ def test_profile_lb_swap(self): rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) exec_profiles = {'rr1': rr1, 'rr2': rr2} with Cluster(execution_profiles=exec_profiles) as cluster: - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) @@ -849,7 +780,7 @@ def test_profile_pool_management(self): node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) node2 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2'])) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() pools = session.get_pool_state() # there are more hosts, but we connected to the ones in the lbp aggregate self.assertGreater(len(cluster.metadata.all_hosts()), 2) @@ -874,7 +805,7 @@ def test_add_profile_timeout(self): node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: - session = cluster.connect(wait_for_all_pools=True) + session = cluster.connect() pools = session.get_pool_state() self.assertGreater(len(cluster.metadata.all_hosts()), 2) self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) diff 
--git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 69566c80ad..2d07b92038 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -99,7 +99,7 @@ class HeartbeatTest(unittest.TestCase): def setUp(self): self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1) - self.session = self.cluster.connect(wait_for_all_pools=True) + self.session = self.cluster.connect() def tearDown(self): self.cluster.shutdown() diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index c6818f7f4b..63a8380902 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -126,7 +126,7 @@ class CustomResultMessageRaw(ResultMessage): type_codes = my_type_codes @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): + def recv_results_rows(cls, f, protocol_version, user_type_map): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] @@ -155,7 +155,7 @@ class CustomResultMessageTracked(ResultMessage): checked_rev_row_set = set() @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): + def recv_results_rows(cls, f, protocol_version, user_type_map): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 7dc3db300e..3560709faa 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -7,12 +7,10 @@ except ImportError: import unittest -from itertools import count - +from cassandra import DriverException, Timeout, AlreadyExists from cassandra.query import tuple_factory from cassandra.cluster import Cluster, NoHostAvailable -from cassandra.concurrent import execute_concurrent_with_args -from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler, ConfigurationException from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster, VERIFY_CYTHON, BasicSharedKeyspaceUnitTestCase, execute_with_retry_tolerant, greaterthancass21 from tests.integration.datatype_utils import update_datatypes @@ -209,49 +207,66 @@ def verify_iterator_data(assertEqual, results): class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): + # A dictionary containing table key to type. 
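# Editor's aside: with this revert the Cython numpy parser (numpy_parser.pyx,
# earlier in this patch) raises ValueError on a NULL in a fixed-width numeric
# column instead of masking it, which is why the False-flagged types below
# must error out while the object-dtype types can carry None. Rough shape of
# the reverted dtype dispatch:
import numpy as np

_cqltype_to_numpy = {'int': np.dtype('>i4')}      # illustrative subset only

def make_array(coltype, array_size, obj_dtype=np.dtype('object')):
    # Unknown / variable-width CQL types fall back to object arrays (which
    # can hold None); fixed-width numeric dtypes cannot represent NULL.
    return np.empty((array_size,), dtype=_cqltype_to_numpy.get(coltype, obj_dtype))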
+ # Boolean dictates whether or not the type can be deserialized with null value + NUMPY_TYPES = {"v1": ('bigint', False), + "v2": ('double', False), + "v3": ('float', False), + "v4": ('int', False), + "v5": ('smallint', False), + "v6": ("ascii", True), + "v7": ("blob", True), + "v8": ("boolean", True), + "v9": ("decimal", True), + "v10": ("inet", True), + "v11": ("text", True), + "v12": ("timestamp", True), + "v13": ("timeuuid", True), + "v14": ("uuid", True), + "v15": ("varchar", True), + "v16": ("varint", True), + } + + def setUp(self): + self.session.client_protocol_handler = NumpyProtocolHandler + self.session.row_factory = tuple_factory + @numpytest @greaterthancass21 def test_null_types(self): """ Test to validate that the numpy protocol handler can deal with null values. @since 3.3.0 - - updated 3.6.0: now numeric types used masked array @jira_ticket PYTHON-550 @expected_result Numpy can handle non mapped types' null values. @test_category data_types:serialization """ - s = self.session - s.row_factory = tuple_factory - s.client_protocol_handler = NumpyProtocolHandler - - table = "%s.%s" % (self.keyspace_name, self.function_table_name) - create_table_with_all_types(table, s, 10) - - begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1 - keys_null = range(begin_unset, begin_unset + 10) - - # scatter some emptry rows in here - insert = "insert into %s (primkey) values (%%s)" % (table,) - execute_concurrent_with_args(s, insert, ((k,) for k in keys_null)) - - result = s.execute("select * from %s" % (table,))[0] - - from numpy.ma import masked, MaskedArray - result_keys = result.pop('primkey') - mapped_index = [v[1] for v in sorted(zip(result_keys, count()))] - - had_masked = had_none = False - for col_array in result.values(): - # these have to be different branches (as opposed to comparing against an 'unset value') - # because None and `masked` have different identity and equals semantics - if isinstance(col_array, MaskedArray): - had_masked = True - [self.assertIsNot(col_array[i], masked) for i in mapped_index[:begin_unset]] - [self.assertIs(col_array[i], masked) for i in mapped_index[begin_unset:]] + + self.create_table_of_types() + self.session.execute("INSERT INTO {0}.{1} (k) VALUES (1)".format(self.keyspace_name, self.function_table_name)) + self.validate_types() + + def create_table_of_types(self): + """ + Builds a table containing all the numpy types + """ + base_ddl = '''CREATE TABLE {0}.{1} (k int PRIMARY KEY'''.format(self.keyspace_name, self.function_table_name, type) + for key, value in NumpyNullTest.NUMPY_TYPES.items(): + base_ddl = base_ddl+", {0} {1}".format(key, value[0]) + base_ddl = base_ddl+")" + execute_with_retry_tolerant(self.session, base_ddl, (DriverException, NoHostAvailable, Timeout), (ConfigurationException, AlreadyExists)) + + def validate_types(self): + """ + Selects each type from the table and expects either an exception or None depending on type + """ + for key, value in NumpyNullTest.NUMPY_TYPES.items(): + select = "SELECT {0} from {1}.{2}".format(key,self.keyspace_name, self.function_table_name) + if value[1]: + rs = execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ()) + self.assertEqual(rs[0].get('v1'), None) else: - had_none = True - [self.assertIsNotNone(col_array[i]) for i in mapped_index[:begin_unset]] - [self.assertIsNone(col_array[i]) for i in mapped_index[begin_unset:]] - self.assertTrue(had_masked) - self.assertTrue(had_none) + with self.assertRaises(ValueError): + 
execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ()) + diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 598dd83971..c317e50c3e 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -79,28 +79,6 @@ def test_host_release_version(self): self.assertTrue(host.release_version.startswith(CASSANDRA_VERSION)) -class MetaDataRemovalTest(unittest.TestCase): - - def setUp(self): - self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, contact_points=['127.0.0.1','127.0.0.2', '127.0.0.3', '126.0.0.186']) - self.cluster.connect() - - def tearDown(self): - self.cluster.shutdown() - - def test_bad_contact_point(self): - """ - Checks to ensure that hosts that are not resolvable are excluded from the contact point list. - - @since 3.6 - @jira_ticket PYTHON-549 - @expected_result Invalid hosts on the contact list should be excluded - - @test_category metadata - """ - self.assertEqual(len(self.cluster.metadata.all_hosts()), 3) - - class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): def test_schema_metadata_disable(self): @@ -1155,12 +1133,14 @@ def test_legacy_tables(self): CREATE TABLE legacy.composite_comp_with_col ( key blob, - column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(b=>org.apache.cassandra.db.marshal.BytesType, s=>org.apache.cassandra.db.marshal.UTF8Type, t=>org.apache.cassandra.db.marshal.TimeUUIDType)', + b blob, + s text, + t timeuuid, "b@6869746d65776974686d75736963" blob, "b@6d616d6d616a616d6d61" blob, - PRIMARY KEY (key, column1) + PRIMARY KEY (key, b, s, t) ) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (column1 ASC) + AND CLUSTERING ORDER BY (b ASC, s ASC, t ASC) AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' AND comment = 'Stores file meta data' AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'} @@ -1273,13 +1253,20 @@ def test_legacy_tables(self): AND read_repair_chance = 0.0 AND speculative_retry = 'NONE'; +/* +Warning: Table legacy.composite_comp_no_col omitted because it has constructs not compatible with CQL (was created via legacy API). 
+ +Approximate structure, for reference: +(this should not be used to reproduce this schema) + CREATE TABLE legacy.composite_comp_no_col ( key blob, - column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(b=>org.apache.cassandra.db.marshal.BytesType, s=>org.apache.cassandra.db.marshal.UTF8Type, t=>org.apache.cassandra.db.marshal.TimeUUIDType)', + column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type, org.apache.cassandra.db.marshal.TimeUUIDType)', + column2 timeuuid, value blob, - PRIMARY KEY (key, column1) + PRIMARY KEY (key, column1, column1, column2) ) WITH COMPACT STORAGE - AND CLUSTERING ORDER BY (column1 ASC) + AND CLUSTERING ORDER BY (column1 ASC, column1 ASC, column2 ASC) AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}' AND comment = 'Stores file meta data' AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'} @@ -1291,7 +1278,8 @@ def test_legacy_tables(self): AND memtable_flush_period_in_ms = 0 AND min_index_interval = 128 AND read_repair_chance = 0.0 - AND speculative_retry = 'NONE';""" + AND speculative_retry = 'NONE'; +*/""" ccm = get_cluster() ccm.run_cli(cli_script) @@ -2047,31 +2035,7 @@ def test_bad_user_aggregate(self): self.assertIn("/*\nWarning:", m.export_as_string()) -class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): - - def test_dct_alias(self): - """ - Tests to make sure DCT's have correct string formatting - - Constructs a DCT and check the format as generated. To insure it matches what is expected - - @since 3.6.0 - @jira_ticket PYTHON-579 - @expected_result DCT subtypes should always have fully qualified names - - @test_category metadata - """ - self.session.execute("CREATE TABLE {0}.{1} (" - "k int PRIMARY KEY," - "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)'," - "c2 Text)".format(self.ks_name, self.function_table_name)) - dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name) - - # Format can very slightly between versions, strip out whitespace for consistency sake - self.assertTrue("c1'org.apache.cassandra.db.marshal.DynamicCompositeType(s=>org.apache.cassandra.db.marshal.UTF8Type,i=>org.apache.cassandra.db.marshal.Int32Type)'" in dct_table.as_cql_query().replace(" ", "")) - - -class Materia3lizedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): +class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase): def setUp(self): if CASS_SERVER_VERSION < (3, 0): @@ -2219,37 +2183,37 @@ def test_create_view_metadata(self): self.assertIsNotNone(score_table.columns['score']) # Validate basic mv information - self.assertEqual(mv.keyspace_name, self.keyspace_name) - self.assertEqual(mv.name, "monthlyhigh") - self.assertEqual(mv.base_table_name, "scores") + self.assertEquals(mv.keyspace_name, self.keyspace_name) + self.assertEquals(mv.name, "monthlyhigh") + self.assertEquals(mv.base_table_name, "scores") self.assertFalse(mv.include_all_columns) # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEqual(len(mv_columns), 6) + self.assertEquals(len(mv_columns), 6) game_column = mv_columns[0] self.assertIsNotNone(game_column) - self.assertEqual(game_column.name, 'game') - self.assertEqual(game_column, mv.partition_key[0]) + self.assertEquals(game_column.name, 'game') + self.assertEquals(game_column, mv.partition_key[0]) year_column = 
mv_columns[1] self.assertIsNotNone(year_column) - self.assertEqual(year_column.name, 'year') - self.assertEqual(year_column, mv.partition_key[1]) + self.assertEquals(year_column.name, 'year') + self.assertEquals(year_column, mv.partition_key[1]) month_column = mv_columns[2] self.assertIsNotNone(month_column) - self.assertEqual(month_column.name, 'month') - self.assertEqual(month_column, mv.partition_key[2]) + self.assertEquals(month_column.name, 'month') + self.assertEquals(month_column, mv.partition_key[2]) def compare_columns(a, b, name): - self.assertEqual(a.name, name) - self.assertEqual(a.name, b.name) - self.assertEqual(a.table, b.table) - self.assertEqual(a.cql_type, b.cql_type) - self.assertEqual(a.is_static, b.is_static) - self.assertEqual(a.is_reversed, b.is_reversed) + self.assertEquals(a.name, name) + self.assertEquals(a.name, b.name) + self.assertEquals(a.table, b.table) + self.assertEquals(a.cql_type, b.cql_type) + self.assertEquals(a.is_static, b.is_static) + self.assertEquals(a.is_reversed, b.is_reversed) score_column = mv_columns[3] compare_columns(score_column, mv.clustering_key[0], 'score') @@ -2326,7 +2290,7 @@ def test_base_table_column_addition_mv(self): self.assertIn("fouls", mv_alltime.columns) mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - self.assertEqual(mv_alltime_fouls_comumn.cql_type, 'int') + self.assertEquals(mv_alltime_fouls_comumn.cql_type, 'int') def test_base_table_type_alter_mv(self): """ @@ -2367,7 +2331,7 @@ def test_base_table_type_alter_mv(self): self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - self.assertEqual(score_column.cql_type, 'blob') + self.assertEquals(score_column.cql_type, 'blob') # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): @@ -2376,7 +2340,7 @@ def test_base_table_type_alter_mv(self): break time.sleep(.2) - self.assertEqual(score_mv_column.cql_type, 'blob') + self.assertEquals(score_mv_column.cql_type, 'blob') def test_metadata_with_quoted_identifiers(self): """ @@ -2429,31 +2393,31 @@ def test_metadata_with_quoted_identifiers(self): self.assertIsNotNone(t1_table.columns['the Value']) # Validate basic mv information - self.assertEqual(mv.keyspace_name, self.keyspace_name) - self.assertEqual(mv.name, "mv1") - self.assertEqual(mv.base_table_name, "t1") + self.assertEquals(mv.keyspace_name, self.keyspace_name) + self.assertEquals(mv.name, "mv1") + self.assertEquals(mv.base_table_name, "t1") self.assertFalse(mv.include_all_columns) # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEqual(len(mv_columns), 3) + self.assertEquals(len(mv_columns), 3) theKey_column = mv_columns[0] self.assertIsNotNone(theKey_column) - self.assertEqual(theKey_column.name, 'theKey') - self.assertEqual(theKey_column, mv.partition_key[0]) + self.assertEquals(theKey_column.name, 'theKey') + self.assertEquals(theKey_column, mv.partition_key[0]) cluster_column = mv_columns[1] self.assertIsNotNone(cluster_column) - self.assertEqual(cluster_column.name, 'the;Clustering') - self.assertEqual(cluster_column.name, mv.clustering_key[0].name) - self.assertEqual(cluster_column.table, mv.clustering_key[0].table) - self.assertEqual(cluster_column.is_static, mv.clustering_key[0].is_static) - self.assertEqual(cluster_column.is_reversed, 
mv.clustering_key[0].is_reversed) + self.assertEquals(cluster_column.name, 'the;Clustering') + self.assertEquals(cluster_column.name, mv.clustering_key[0].name) + self.assertEquals(cluster_column.table, mv.clustering_key[0].table) + self.assertEquals(cluster_column.is_static, mv.clustering_key[0].is_static) + self.assertEquals(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) value_column = mv_columns[2] self.assertIsNotNone(value_column) - self.assertEqual(value_column.name, 'the Value') + self.assertEquals(value_column.name, 'the Value') @dseonly diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 18f35c15f1..13758b65ad 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -14,8 +14,6 @@ import time -from cassandra.policies import WhiteListRoundRobinPolicy, FallthroughRetryPolicy - try: import unittest2 as unittest except ImportError: @@ -26,8 +24,7 @@ from cassandra.cluster import Cluster, NoHostAvailable from tests.integration import get_cluster, get_node, use_singledc, PROTOCOL_VERSION, execute_until_pass -from greplin import scales -from tests.integration import BasicSharedKeyspaceUnitTestCaseWTable + def setup_module(): use_singledc() @@ -36,11 +33,8 @@ def setup_module(): class MetricsTests(unittest.TestCase): def setUp(self): - contact_point = ['127.0.0.2'] - self.cluster = Cluster(contact_points=contact_point, metrics_enabled=True, protocol_version=PROTOCOL_VERSION, - load_balancing_policy=WhiteListRoundRobinPolicy(contact_point), - default_retry_policy=FallthroughRetryPolicy()) - self.session = self.cluster.connect("test3rf", wait_for_all_pools=True) + self.cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION) + self.session = self.cluster.connect("test3rf") def tearDown(self): self.cluster.shutdown() @@ -50,6 +44,8 @@ def test_connection_error(self): Trigger and ensure connection_errors are counted Stop all node with the driver knowing about the "DOWN" states. 
""" + + # Test writes for i in range(0, 100): self.session.execute_async("INSERT INTO test (k, v) VALUES ({0}, {1})".format(i, i)) @@ -149,13 +145,13 @@ def test_unavailable(self): query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query) - self.assertEqual(self.cluster.metrics.stats.unavailables, 1) + self.assertEqual(2, self.cluster.metrics.stats.unavailables) # Test write query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query, timeout=None) - self.assertEqual(self.cluster.metrics.stats.unavailables, 2) + self.assertEqual(4, self.cluster.metrics.stats.unavailables) finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) # Give some time for the cluster to come back up, for the next test @@ -174,102 +170,3 @@ def test_unavailable(self): # def test_retry(self): # # TODO: Look for ways to generate retries # pass - - -class MetricsNamespaceTest(BasicSharedKeyspaceUnitTestCaseWTable): - - def test_metrics_per_cluster(self): - """ - Test to validate that metrics can be scopped to invdividual clusters - @since 3.6.0 - @jira_ticket PYTHON-561 - @expected_result metrics should be scopped to a cluster level - - @test_category metrics - """ - - cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, - default_retry_policy=FallthroughRetryPolicy()) - cluster2.connect(self.ks_name, wait_for_all_pools=True) - - query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - self.session.execute(query) - - # Pause node so it shows as unreachable to coordinator - get_node(1).pause() - - try: - # Test write - query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(WriteTimeout): - self.session.execute(query, timeout=None) - finally: - get_node(1).resume() - - # Change the scales stats_name of the cluster2 - cluster2.metrics.set_stats_name('cluster2-metrics') - - stats_cluster1 = self.cluster.metrics.get_stats() - stats_cluster2 = cluster2.metrics.get_stats() - - # Test direct access to stats - self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) - self.assertEqual(0, cluster2.metrics.stats.write_timeouts) - - # Test direct access to a child stats - self.assertNotEqual(0.0, self.cluster.metrics.request_timer['mean']) - self.assertEqual(0.0, cluster2.metrics.request_timer['mean']) - - # Test access via metrics.get_stats() - self.assertNotEqual(0.0, stats_cluster1['request_timer']['mean']) - self.assertEqual(0.0, stats_cluster2['request_timer']['mean']) - - # Test access by stats_name - self.assertEqual(0.0, scales.getStats()['cluster2-metrics']['request_timer']['mean']) - - cluster2.shutdown() - - def test_duplicate_metrics_per_cluster(self): - """ - Test to validate that cluster metrics names can't overlap. - @since 3.6.0 - @jira_ticket PYTHON-561 - @expected_result metric names should not be allowed to be same. 
- - @test_category metrics - """ - cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, - default_retry_policy=FallthroughRetryPolicy()) - - cluster3 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, - default_retry_policy=FallthroughRetryPolicy()) - - # Ensure duplicate metric names are not allowed - cluster2.metrics.set_stats_name("appcluster") - cluster2.metrics.set_stats_name("appcluster") - with self.assertRaises(ValueError): - cluster3.metrics.set_stats_name("appcluster") - cluster3.metrics.set_stats_name("devops") - - session2 = cluster2.connect(self.ks_name, wait_for_all_pools=True) - session3 = cluster3.connect(self.ks_name, wait_for_all_pools=True) - - # Basic validation that naming metrics doesn't impact their segration or accuracy - for i in range(10): - query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - session2.execute(query) - - for i in range(5): - query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - session3.execute(query) - - self.assertEqual(cluster2.metrics.get_stats()['request_timer']['count'], 10) - self.assertEqual(cluster3.metrics.get_stats()['request_timer']['count'], 5) - - # Check scales to ensure they are appropriately named - self.assertTrue("appcluster" in scales._Stats.stats.keys()) - self.assertTrue("devops" in scales._Stats.stats.keys()) - - - - diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 719f2b1fc9..4bd742bc42 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -26,7 +26,7 @@ from cassandra.cluster import Cluster, NoHostAvailable from cassandra.policies import HostDistance, RoundRobinPolicy -from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions +from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3, MockLoggingHandler import time import re @@ -191,7 +191,7 @@ def test_incomplete_query_trace(self): self.assertTrue(self._wait_for_trace_to_populate(trace.trace_id)) # Delete trace duration from the session (this is what the driver polls for "complete") - delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) + delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) self.session.execute(delete_statement) self.assertTrue(self._wait_for_trace_to_delete(trace.trace_id)) @@ -225,7 +225,7 @@ def _wait_for_trace_to_delete(self, trace_id): return count != retry_max def _is_trace_present(self, trace_id): - select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format(trace_id), consistency_level=ConsistencyLevel.ALL) + select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {}".format(trace_id), consistency_level=ConsistencyLevel.ALL) ssrs = self.session.execute(select_statement) if(ssrs[0].duration is None): return False @@ -356,39 +356,6 @@ def make_query_plan(self, working_keyspace=None, query=None): return list(self._live_hosts) -class 
PreparedStatementMetdataTest(unittest.TestCase): - - def test_prepared_metadata_generation(self): - """ - Test to validate that result metadata is appropriately populated across protocol version - - In protocol version 1 result metadata is retrieved everytime the statement is issued. In all - other protocol versions it's set once upon the prepare, then re-used. This test ensures that it manifests - it's self the same across multiple protocol versions. - - @since 3.6.0 - @jira_ticket PYTHON-71 - @expected_result result metadata is consistent. - """ - - base_line = None - for proto_version in get_supported_protocol_versions(): - cluster = Cluster(protocol_version=proto_version) - session = cluster.connect() - select_statement = session.prepare("SELECT * FROM system.local") - if proto_version == 1: - self.assertEqual(select_statement.result_metadata, None) - else: - self.assertNotEqual(select_statement.result_metadata, None) - future = session.execute_async(select_statement) - results = future.result() - if base_line is None: - base_line = results[0].__dict__.keys() - else: - self.assertEqual(base_line, results[0].__dict__.keys()) - cluster.shutdown() - - class PreparedStatementArgTest(unittest.TestCase): def test_prepare_on_all_hosts(self): @@ -914,73 +881,73 @@ def test_mv_filtering(self): query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEqual(results[0].game, 'Coup') - self.assertEqual(results[0].year, 2015) - self.assertEqual(results[0].month, 5) - self.assertEqual(results[0].day, 1) - self.assertEqual(results[0].score, 4000) - self.assertEqual(results[0].user, "pcmanus") + self.assertEquals(results[0].game, 'Coup') + self.assertEquals(results[0].year, 2015) + self.assertEquals(results[0].month, 5) + self.assertEquals(results[0].day, 1) + self.assertEquals(results[0].score, 4000) + self.assertEquals(results[0].user, "pcmanus") # Test prepared statement and daily high filtering prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? and day=?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2)) results = self.session.execute(bound_query) - self.assertEqual(results[0].game, 'Coup') - self.assertEqual(results[0].year, 2015) - self.assertEqual(results[0].month, 6) - self.assertEqual(results[0].day, 2) - self.assertEqual(results[0].score, 2000) - self.assertEqual(results[0].user, "pcmanus") - - self.assertEqual(results[1].game, 'Coup') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 6) - self.assertEqual(results[1].day, 2) - self.assertEqual(results[1].score, 1000) - self.assertEqual(results[1].user, "tjake") + self.assertEquals(results[0].game, 'Coup') + self.assertEquals(results[0].year, 2015) + self.assertEquals(results[0].month, 6) + self.assertEquals(results[0].day, 2) + self.assertEquals(results[0].score, 2000) + self.assertEquals(results[0].user, "pcmanus") + + self.assertEquals(results[1].game, 'Coup') + self.assertEquals(results[1].year, 2015) + self.assertEquals(results[1].month, 6) + self.assertEquals(results[1].day, 2) + self.assertEquals(results[1].score, 1000) + self.assertEquals(results[1].user, "tjake") # Test montly high range queries prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? 
and score <= ?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) results = self.session.execute(bound_query) - self.assertEqual(results[0].game, 'Coup') - self.assertEqual(results[0].year, 2015) - self.assertEqual(results[0].month, 6) - self.assertEqual(results[0].day, 20) - self.assertEqual(results[0].score, 3500) - self.assertEqual(results[0].user, "jbellis") - - self.assertEqual(results[1].game, 'Coup') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 6) - self.assertEqual(results[1].day, 9) - self.assertEqual(results[1].score, 2700) - self.assertEqual(results[1].user, "jmckenzie") - - self.assertEqual(results[2].game, 'Coup') - self.assertEqual(results[2].year, 2015) - self.assertEqual(results[2].month, 6) - self.assertEqual(results[2].day, 1) - self.assertEqual(results[2].score, 2500) - self.assertEqual(results[2].user, "iamaleksey") + self.assertEquals(results[0].game, 'Coup') + self.assertEquals(results[0].year, 2015) + self.assertEquals(results[0].month, 6) + self.assertEquals(results[0].day, 20) + self.assertEquals(results[0].score, 3500) + self.assertEquals(results[0].user, "jbellis") + + self.assertEquals(results[1].game, 'Coup') + self.assertEquals(results[1].year, 2015) + self.assertEquals(results[1].month, 6) + self.assertEquals(results[1].day, 9) + self.assertEquals(results[1].score, 2700) + self.assertEquals(results[1].user, "jmckenzie") + + self.assertEquals(results[2].game, 'Coup') + self.assertEquals(results[2].year, 2015) + self.assertEquals(results[2].month, 6) + self.assertEquals(results[2].day, 1) + self.assertEquals(results[2].score, 2500) + self.assertEquals(results[2].user, "iamaleksey") # Test filtered user high scores query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEqual(results[0].game, 'Chess') - self.assertEqual(results[0].year, 2015) - self.assertEqual(results[0].month, 6) - self.assertEqual(results[0].day, 21) - self.assertEqual(results[0].score, 3500) - self.assertEqual(results[0].user, "jbellis") - - self.assertEqual(results[1].game, 'Chess') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 1) - self.assertEqual(results[1].day, 25) - self.assertEqual(results[1].score, 3200) - self.assertEqual(results[1].user, "pcmanus") + self.assertEquals(results[0].game, 'Chess') + self.assertEquals(results[0].year, 2015) + self.assertEquals(results[0].month, 6) + self.assertEquals(results[0].day, 21) + self.assertEquals(results[0].score, 3500) + self.assertEquals(results[0].user, "jbellis") + + self.assertEquals(results[1].game, 'Chess') + self.assertEquals(results[1].year, 2015) + self.assertEquals(results[1].month, 1) + self.assertEquals(results[1].day, 25) + self.assertEquals(results[1].score, 3200) + self.assertEquals(results[1].user, "pcmanus") class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index f959d4d9f9..736e7957e2 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -31,7 +31,7 @@ from tests.unit.cython.utils import cythontest from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \ - BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30 + BasicSharedKeyspaceUnitTestCase, 
greaterthancass20, lessthancass30 from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ get_sample, get_collection_sample @@ -796,7 +796,7 @@ def test_cython_decimal(self): class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): - @greaterthancass21 + @greaterthancass20 @lessthancass30 def test_nested_types_with_protocol_version(self): """ diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py index 917b3a7f6e..4011047fc8 100644 --- a/tests/integration/standard/utils.py +++ b/tests/integration/standard/utils.py @@ -4,7 +4,6 @@ from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample - def create_table_with_all_types(table_name, session, N): """ Method that given a table_name and session construct a table that contains diff --git a/tests/unit/cqlengine/__init__.py b/tests/unit/cqlengine/__init__.py deleted file mode 100644 index 87fc3685e0..0000000000 --- a/tests/unit/cqlengine/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2013-2016 DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tests/unit/cqlengine/test_columns.py b/tests/unit/cqlengine/test_columns.py deleted file mode 100644 index 181c103515..0000000000 --- a/tests/unit/cqlengine/test_columns.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2013-2016 DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -try: - import unittest2 as unittest -except ImportError: - import unittest # noqa - -from cassandra.cqlengine.columns import Column - - -class ColumnTest(unittest.TestCase): - - def test_comparisons(self): - c0 = Column() - c1 = Column() - self.assertEqual(c1.position - c0.position, 1) - - # __ne__ - self.assertNotEqual(c0, c1) - self.assertNotEqual(c0, object()) - - # __eq__ - self.assertEqual(c0, c0) - self.assertFalse(c0 == object()) - - # __lt__ - self.assertLess(c0, c1) - try: - c0 < object() # this raises for Python 3 - except TypeError: - pass - - # __le__ - self.assertLessEqual(c0, c1) - self.assertLessEqual(c0, c0) - try: - c0 <= object() # this raises for Python 3 - except TypeError: - pass - - # __gt__ - self.assertGreater(c1, c0) - try: - c1 > object() # this raises for Python 3 - except TypeError: - pass - - # __ge__ - self.assertGreaterEqual(c1, c0) - self.assertGreaterEqual(c1, c1) - try: - c1 >= object() # this raises for Python 3 - except TypeError: - pass - - diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index 948f6f2502..a535bf2260 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -17,15 +17,12 @@ import unittest2 as unittest except ImportError: import unittest # noqa - from itertools import cycle from mock import Mock import time import threading from six.moves.queue import PriorityQueue -import sys -from cassandra.cluster import Cluster, Session from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args @@ -232,19 +229,3 @@ def validate_result_ordering(self, results): current_time_added = list(result)[0] self.assertLess(last_time_added, current_time_added) last_time_added = current_time_added - - def test_recursion_limited(self): - """ - Verify that recursion is controlled when raise_on_first_error=False and something is wrong with the query. 
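-        A rough usage sketch of the behavior under test (table name and
-        parameters here are assumptions, not from the driver):
-
-            results = execute_concurrent_with_args(
-                session, "INSERT INTO test (k, v) VALUES (%s, %s)",
-                [(i, i) for i in range(1000)], raise_on_first_error=False)
-            # each entry is (success, result_or_exception); failures are
-            # returned rather than raised, and a long run of failures must
-            # not overflow the interpreter's recursion limit
-            failures = [res for ok, res in results if not ok]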
- - PYTHON-585 - """ - max_recursion = sys.getrecursionlimit() - s = Session(Cluster(), []) - self.assertRaises(TypeError, execute_concurrent_with_args, s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) - - results = execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=False) # previously - self.assertEqual(len(results), max_recursion) - for r in results: - self.assertFalse(r[0]) - self.assertIsInstance(r[1], TypeError) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index b8cb640b46..2ac10a590f 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -112,7 +112,7 @@ def test_negative_body_length(self, *args): def test_unsupported_cql_version(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() c.cql_version = "3.0.3" @@ -135,7 +135,7 @@ def test_unsupported_cql_version(self, *args): def test_prefer_lz4_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() c.cql_version = "3.0.3" @@ -158,7 +158,7 @@ def test_prefer_lz4_compression(self, *args): def test_requested_compression_not_available(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() # request lz4 compression c.compression = "lz4" @@ -188,7 +188,7 @@ def test_requested_compression_not_available(self, *args): def test_use_requested_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} c.defunct = Mock() # request snappy compression c.compression = "snappy" diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 5fe230f402..fb0ca21711 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -165,8 +165,7 @@ def test_spawn_when_at_max(self): def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, - max_request_id=100, signaled_error=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -178,14 +177,14 @@ def test_return_defunct_connection(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - self.assertTrue(session.submit.call_args) + conn.close.assert_called_once() + session.submit.assert_called_once() self.assertFalse(pool.is_shutdown) def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, - max_request_id=100, signaled_error=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, 
is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -197,15 +196,15 @@ def test_return_defunct_connection_on_down_host(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - self.assertTrue(session.cluster.signal_connection_failure.call_args) - self.assertTrue(conn.close.call_args) + session.cluster.signal_connection_failure.assert_called_once() + conn.close.assert_called_once() self.assertFalse(session.submit.called) self.assertTrue(pool.is_shutdown) def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100, signaled_error=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -217,7 +216,7 @@ def test_return_closed_connection(self): pool.return_connection(conn) # a new creation should be scheduled - self.assertTrue(session.submit.call_args) + session.submit.assert_called_once() self.assertFalse(pool.is_shutdown) def test_host_instantiations(self): diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 555dfe3834..d48b5d9573 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -91,7 +91,7 @@ def setUpClass(cls): routing_key_indexes=[1, 0], query=None, keyspace='keyspace', - protocol_version=cls.protocol_version, result_metadata=None) + protocol_version=cls.protocol_version) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): @@ -130,8 +130,7 @@ def test_inherit_fetch_size(self): routing_key_indexes=[], query=None, keyspace=keyspace, - protocol_version=self.protocol_version, - result_metadata=None) + protocol_version=self.protocol_version) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) self.assertEqual(1234, bound_statement.fetch_size) @@ -164,8 +163,7 @@ def test_values_none(self): routing_key_indexes=[], query=None, keyspace='whatever', - protocol_version=self.protocol_version, - result_metadata=None) + protocol_version=self.protocol_version) bound = prepared_statement.bind(None) self.assertListEqual(bound.values, []) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 88b08af878..ad5bb3e93b 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -67,7 +67,7 @@ def test_result_message(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) rf._set_result(self.make_mock_response([{'col': 'val'}])) result = rf.result() @@ -192,7 +192,7 @@ def test_retry_policy_says_retry(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - 
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(result) @@ -210,7 +210,7 @@ def test_retry_policy_says_retry(self): # an UnavailableException rf.session._pools.get.assert_called_with('ip1') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) def test_retry_with_different_host(self): session = self.make_session() @@ -225,7 +225,7 @@ def test_retry_with_different_host(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) @@ -243,7 +243,7 @@ def test_retry_with_different_host(self): # it should try with a different host rf.session._pools.get.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) # the consistency level should be the same self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) @@ -480,7 +480,7 @@ def test_prepared_query_not_found(self): result = Mock(spec=PreparedQueryNotFound, info='a' * 16) rf._set_result(result) - self.assertTrue(session.submit.call_args) + session.submit.assert_called_once() args, kwargs = session.submit.call_args self.assertEqual(rf._reprepare, args[-2]) self.assertIsInstance(args[-1], PrepareMessage) From 0a1c61fd72d5ef18f3d47d7208edc34f4dc4defd Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Fri, 12 Aug 2016 13:57:44 -0500 Subject: [PATCH 02/10] Revert "Revert "Merge tag '3.6.0' into cassandra-test"" This reverts commit f433a80cb0d8d691b25a4d826866fda7916cd2f1. 
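Among the changes being restored is PYTHON-71, which lets the driver capture a
prepared statement's result metadata once at prepare time and ask the server to
skip it on each EXECUTE under protocol v2+. A minimal sketch of the observable
behavior (assuming a reachable node running driver 3.6; not taken verbatim from
the patch):

    from cassandra.cluster import Cluster

    cluster = Cluster()
    session = cluster.connect()

    prepared = session.prepare("SELECT release_version FROM system.local")
    # Under protocol v2+ this is populated at prepare time and reused;
    # under protocol v1 it stays None and metadata rides with each response.
    print(prepared.result_metadata is not None)

    row = session.execute(prepared)[0]
    print(row.release_version)
    cluster.shutdown()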
--- CHANGELOG.rst | 31 +++ cassandra/__init__.py | 2 +- cassandra/cluster.py | 92 +++++--- cassandra/concurrent.py | 14 +- cassandra/connection.py | 18 +- cassandra/cqlengine/columns.py | 92 +++++++- cassandra/cqlengine/management.py | 36 ++- cassandra/cqlengine/models.py | 33 ++- cassandra/cqlengine/query.py | 74 +++++- cassandra/cqlengine/statements.py | 4 +- cassandra/cqltypes.py | 24 +- cassandra/encoder.py | 3 +- cassandra/io/eventletreactor.py | 14 -- cassandra/io/geventreactor.py | 18 +- cassandra/io/libevreactor.py | 10 +- cassandra/metadata.py | 27 +-- cassandra/metrics.py | 35 ++- cassandra/numpy_parser.pyx | 30 ++- cassandra/protocol.py | 59 +++-- cassandra/query.py | 31 ++- cassandra/row_parser.pyx | 4 +- cassandra/type_codes.py | 1 - cassandra/util.py | 4 + docs.yaml | 5 + docs/api/cassandra/cqlengine/models.rst | 6 +- docs/api/cassandra/cqlengine/query.rst | 8 + docs/cqlengine/queryset.rst | 36 +++ docs/getting_started.rst | 2 +- test-requirements.txt | 2 +- tests/integration/__init__.py | 84 ++++++- tests/integration/cqlengine/__init__.py | 2 +- .../cqlengine/columns/test_validation.py | 223 +++++++++++++++++- .../model/test_class_construction.py | 23 +- .../integration/cqlengine/model/test_model.py | 36 ++- .../cqlengine/model/test_model_io.py | 63 ++++- .../cqlengine/model/test_updates.py | 36 ++- .../integration/cqlengine/query/test_named.py | 4 +- .../cqlengine/query/test_queryoperators.py | 6 +- .../cqlengine/query/test_queryset.py | 69 +++++- .../cqlengine/test_context_query.py | 127 ++++++++++ .../cqlengine/test_lwt_conditional.py | 15 ++ tests/integration/cqlengine/test_ttl.py | 43 +++- tests/integration/long/test_ssl.py | 42 +++- tests/integration/standard/test_cluster.py | 89 ++++++- tests/integration/standard/test_connection.py | 2 +- .../standard/test_custom_protocol_handler.py | 4 +- .../standard/test_cython_protocol_handlers.py | 91 +++---- tests/integration/standard/test_metadata.py | 134 +++++++---- tests/integration/standard/test_metrics.py | 117 ++++++++- tests/integration/standard/test_query.py | 143 ++++++----- tests/integration/standard/test_types.py | 4 +- tests/integration/standard/utils.py | 1 + tests/unit/cqlengine/__init__.py | 14 ++ tests/unit/cqlengine/test_columns.py | 68 ++++++ tests/unit/test_concurrent.py | 19 ++ tests/unit/test_connection.py | 8 +- tests/unit/test_host_connection_pool.py | 17 +- tests/unit/test_parameter_binding.py | 8 +- tests/unit/test_response_future.py | 12 +- 59 files changed, 1758 insertions(+), 461 deletions(-) create mode 100644 tests/integration/cqlengine/test_context_query.py create mode 100644 tests/unit/cqlengine/__init__.py create mode 100644 tests/unit/cqlengine/test_columns.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 273657131a..3db920828b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,34 @@ +3.6.0 +===== +August 1, 2016 + +Features +-------- +* Handle null values in NumpyProtocolHandler (PYTHON-553) +* Collect greplin scales stats per cluster (PYTHON-561) +* Update mock unit test dependency requirement (PYTHON-591) +* Handle Missing CompositeType metadata following C* upgrade (PYTHON-562) +* Improve Host.is_up state for HostDistance.IGNORED hosts (PYTHON-551) +* Utilize v2 protocol's ability to skip result set metadata for prepared statement execution (PYTHON-71) +* Return from Cluster.connect() when first contact point connection(pool) is opened (PYTHON-105) +* cqlengine: Add ContextQuery to allow cqlengine models to switch the keyspace context easily (PYTHON-598) + +Bug Fixes +--------- +* 
Fix geventreactor with SSL support (PYTHON-600) +* Don't downgrade protocol version if explicitly set (PYTHON-537) +* Nonexistent contact point tries to connect indefinitely (PYTHON-549) +* Execute_concurrent can exceed max recursion depth in failure mode (PYTHON-585) +* Libev loop shutdown race (PYTHON-578) +* Include aliases in DCT type string (PYTHON-579) +* cqlengine: Comparison operators for Columns (PYTHON-595) +* cqlengine: disentangle default_time_to_live table option from model query default TTL (PYTHON-538) +* cqlengine: pk__token column name issue with the equality operator (PYTHON-584) +* cqlengine: Fix "__in" filtering operator converts True to string "True" automatically (PYTHON-596) +* cqlengine: Avoid LWTExceptions when updating columns that are part of the condition (PYTHON-580) +* cqlengine: Cannot execute a query when the filter contains all columns (PYTHON-599) +* cqlengine: routing key computation issue when a primary key column is overriden by model inheritance (PYTHON-576) + 3.5.0 ===== June 27, 2016 diff --git a/cassandra/__init__.py b/cassandra/__init__.py index c8212c70e3..1a02d8a892 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -22,7 +22,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 5, 0) +__version_info__ = (3, 6, 0) __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 99509d2233..536ae71c14 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -20,7 +20,7 @@ import atexit from collections import defaultdict, Mapping -from concurrent.futures import ThreadPoolExecutor, wait as wait_futures +from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures from copy import copy from functools import partial, wraps from itertools import groupby, count @@ -356,10 +356,11 @@ class Cluster(object): """ The maximum version of the native protocol to use. - The driver will automatically downgrade version based on a negotiation with - the server, but it is most efficient to set this to the maximum supported - by your version of Cassandra. Setting this will also prevent conflicting - versions negotiated if your cluster is upgraded. + If not set in the constructor, the driver will automatically downgrade + version based on a negotiation with the server, but it is most efficient + to set this to the maximum supported by your version of Cassandra. + Setting this will also prevent conflicting versions negotiated if your + cluster is upgraded. Version 2 of the native protocol adds support for lightweight transactions, batch operations, and automatic query paging. 
The v2 protocol is @@ -388,6 +389,8 @@ class Cluster(object): +-------------------+-------------------+ | 2.2 | 1, 2, 3, 4 | +-------------------+-------------------+ + | 3.x | 3, 4 | + +-------------------+-------------------+ """ compression = True @@ -719,6 +722,7 @@ def token_metadata_enabled(self, enabled): _prepared_statements = None _prepared_statement_lock = None _idle_heartbeat = None + _protocol_version_explicit = False _user_types = None """ @@ -742,7 +746,7 @@ def __init__(self, ssl_options=None, sockopts=None, cql_version=None, - protocol_version=4, + protocol_version=_NOT_SET, executor_threads=2, max_schema_agreement_wait=10, control_connection_timeout=2.0, @@ -777,7 +781,11 @@ def __init__(self, for endpoint in socket.getaddrinfo(a, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)] self.compression = compression - self.protocol_version = protocol_version + + if protocol_version is not _NOT_SET: + self.protocol_version = protocol_version + self._protocol_version_explicit = True + self.auth_provider = auth_provider if load_balancing_policy is not None: @@ -1117,6 +1125,9 @@ def _make_connection_kwargs(self, address, kwargs_dict): return kwargs_dict def protocol_downgrade(self, host_addr, previous_version): + if self._protocol_version_explicit: + raise DriverException("ProtocolError returned from server while using explicitly set client protocol_version %d" % (previous_version,)) + new_version = previous_version - 1 if new_version < self.protocol_version: if new_version >= MIN_SUPPORTED_VERSION: @@ -1127,7 +1138,7 @@ def protocol_downgrade(self, host_addr, previous_version): else: raise DriverException("Cannot downgrade protocol version (%d) below minimum supported version: %d" % (new_version, MIN_SUPPORTED_VERSION)) - def connect(self, keyspace=None): + def connect(self, keyspace=None, wait_for_all_pools=False): """ Creates and returns a new :class:`~.Session` object. 
If `keyspace` is specified, that keyspace will be the default keyspace for @@ -1154,6 +1165,13 @@ def connect(self, keyspace=None): try: self.control_connection.connect() + + # we set all contact points up for connecting, but we won't infer state after this + for address in self.contact_points_resolved: + h = self.metadata.get_host(address) + if h and self.profile_manager.distance(h) == HostDistance.IGNORED: + h.is_up = None + log.debug("Control connection created") except Exception: log.exception("Control connection failed to connect, " @@ -1167,9 +1185,9 @@ def connect(self, keyspace=None): self._idle_heartbeat = ConnectionHeartbeat(self.idle_heartbeat_interval, self.get_connection_holders) self._is_setup = True - session = self._new_session() - if keyspace: - session.set_keyspace(keyspace) + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) return session def get_connection_holders(self): @@ -1213,8 +1231,8 @@ def __enter__(self): def __exit__(self, *args): self.shutdown() - def _new_session(self): - session = Session(self, self.metadata.all_hosts()) + def _new_session(self, keyspace): + session = Session(self, self.metadata.all_hosts(), keyspace) self._session_register_user_types(session) self.sessions.add(session) return session @@ -1334,6 +1352,7 @@ def on_up(self, host): else: if not have_future: with host.lock: + host.set_up() host._currently_handling_node_up = False # for testing purposes @@ -1372,10 +1391,11 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): return with host.lock: - if (not host.is_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): + was_up = host.is_up + host.set_down() + if (not was_up and not expect_host_to_be_down) or host.is_currently_reconnecting(): return - host.set_down() log.warning("Host %s has been marked down", host) @@ -1888,9 +1908,10 @@ def default_serial_consistency_level(self, cl): _profile_manager = None _metrics = None - def __init__(self, cluster, hosts): + def __init__(self, cluster, hosts, keyspace=None): self.cluster = cluster self.hosts = hosts + self.keyspace = keyspace self._lock = RLock() self._pools = {} @@ -1901,14 +1922,13 @@ def __init__(self, cluster, hosts): self.encoder = Encoder() # create connection pools in parallel - futures = [] + self._initial_connect_futures = set() for host in hosts: future = self.add_or_renew_pool(host, is_host_addition=False) - if future is not None: - futures.append(future) + if future: + self._initial_connect_futures.add(future) + wait_futures(self._initial_connect_futures, return_when=FIRST_COMPLETED) - for future in futures: - future.result() def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False, custom_payload=None, execution_profile=EXEC_PROFILE_DEFAULT): """ @@ -2045,11 +2065,11 @@ def _create_response_future(self, query, parameters, trace, custom_payload, time query_string, cl, serial_cl, fetch_size, timestamp=timestamp) elif isinstance(query, BoundStatement): + prepared_statement = query.prepared_statement message = ExecuteMessage( - query.prepared_statement.query_id, query.values, cl, + prepared_statement.query_id, query.values, cl, serial_cl, fetch_size, - timestamp=timestamp) - prepared_statement = query.prepared_statement + timestamp=timestamp, skip_meta=bool(prepared_statement.result_metadata)) elif isinstance(query, BatchStatement): if self._protocol_version < 2: raise UnsupportedOperation( @@ -2124,14 +2144,14 @@ def prepare(self, query, custom_payload=None): 
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout) try: future.send_request() - query_id, column_metadata, pk_indexes = future.result() + query_id, bind_metadata, pk_indexes, result_metadata = future.result() except Exception: log.exception("Error preparing query:") raise prepared_statement = PreparedStatement.from_message( - query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, - self._protocol_version) + query_id, bind_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace, + self._protocol_version, result_metadata) prepared_statement.custom_payload = future.custom_payload self.cluster.add_prepared(query_id, prepared_statement) @@ -2189,7 +2209,7 @@ def shutdown(self): else: self.is_shutdown = True - for pool in self._pools.values(): + for pool in list(self._pools.values()): pool.shutdown() def __enter__(self): @@ -2774,9 +2794,8 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, for old_host in self._cluster.metadata.all_hosts(): if old_host.address != connection.host and old_host.address not in found_hosts: should_rebuild_token_map = True - if old_host.address not in self._cluster.contact_points: - log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) - self._cluster.remove_host(old_host) + log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) + self._cluster.remove_host(old_host) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: @@ -2929,14 +2948,13 @@ def _get_schema_mismatches(self, peers_result, local_result, local_address): if local_row.get("schema_version"): versions[local_row.get("schema_version")].add(local_address) - pm = self._cluster.profile_manager for row in peers_result: schema_ver = row.get('schema_version') if not schema_ver: continue addr = self._rpc_from_peer_row(row) peer = self._cluster.metadata.get_host(addr) - if peer and peer.is_up and pm.distance(peer) != HostDistance.IGNORED: + if peer and peer.is_up is not False: versions[schema_ver].add(addr) if len(versions) == 1: @@ -3254,7 +3272,9 @@ def _query(self, host, message=None, cb=None): # TODO get connectTimeout from cluster settings connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection - connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message) + result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] + connection.send_msg(message, request_id, cb=cb, encoder=self._protocol_handler.encode_message, decoder=self._protocol_handler.decode_message, + result_metadata=result_meta) return request_id except NoConnectionsAvailable as exc: log.debug("All connections for host %s are at capacity, moving to the next host", host) @@ -3757,8 +3777,8 @@ def add_callbacks(self, callback, errback, def clear_callbacks(self): with self._callback_lock: - self._callback = [] - self._errback = [] + self._callbacks = [] + self._errbacks = [] def __str__(self): result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result diff --git a/cassandra/concurrent.py b/cassandra/concurrent.py index 48cbab3e24..a08c0292e3 100644 --- a/cassandra/concurrent.py +++ b/cassandra/concurrent.py @@ -94,6 +94,8 @@ def execute_concurrent(session, statements_and_parameters, concurrency=100, rais class _ConcurrentExecutor(object): + max_error_recursion = 100 + 
def __init__(self, session, statements_and_params): self.session = session self._enum_statements = enumerate(iter(statements_and_params)) @@ -102,6 +104,7 @@ def __init__(self, session, statements_and_params): self._results_queue = [] self._current = 0 self._exec_count = 0 + self._exec_depth = 0 def execute(self, concurrency, fail_fast): self._fail_fast = fail_fast @@ -125,6 +128,7 @@ def _execute_next(self): pass def _execute(self, idx, statement, params): + self._exec_depth += 1 try: future = self.session.execute_async(statement, params, timeout=None) args = (future, idx) @@ -135,7 +139,15 @@ def _execute(self, idx, statement, params): # exc_info with fail_fast to preserve stack trace info when raising on the client thread # (matches previous behavior -- not sure why we wouldn't want stack trace in the other case) e = sys.exc_info() if self._fail_fast and six.PY2 else exc - self._put_result(e, idx, False) + + # If we're not failing fast and all executions are raising, there is a chance of recursing + # here as subsequent requests are attempted. If we hit this threshold, schedule this result/retry + # and let the event loop thread return. + if self._exec_depth < self.max_error_recursion: + self._put_result(e, idx, False) + else: + self.session.submit(self._put_result, e, idx, False) + self._exec_depth -= 1 def _on_success(self, result, future, idx): future.clear_callbacks() diff --git a/cassandra/connection.py b/cassandra/connection.py index f43edc4b5d..11da8a4afe 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -149,8 +149,8 @@ class ProtocolVersionUnsupported(ConnectionException): Server rejected startup message due to unsupported protocol version """ def __init__(self, host, startup_version): - super(ProtocolVersionUnsupported, self).__init__("Unsupported protocol version on %s: %d", - (host, startup_version)) + msg = "Unsupported protocol version on %s: %d" % (host, startup_version) + super(ProtocolVersionUnsupported, self).__init__(msg, host) self.startup_version = startup_version @@ -345,6 +345,7 @@ def _connect_socket(self): self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) self._socket.settimeout(self.connect_timeout) self._socket.connect(sockaddr) + self._socket.settimeout(None) if self._check_hostname: ssl.match_hostname(self._socket.getpeercert(), self.host) sockerr = None @@ -404,7 +405,7 @@ def try_callback(cb): id(self), self.host, exc_info=True) # run first callback from this thread to ensure pool state before leaving - cb, _ = requests.popitem()[1] + cb, _, _ = requests.popitem()[1] try_callback(cb) if not requests: @@ -414,7 +415,7 @@ def try_callback(cb): # The default callback and retry logic is fairly expensive -- we don't # want to tie up the event thread when there are many requests def err_all_callbacks(): - for cb, _ in requests.values(): + for cb, _, _ in requests.values(): try_callback(cb) if len(requests) < Connection.CALLBACK_ERR_THREAD_THRESHOLD: err_all_callbacks() @@ -445,7 +446,7 @@ def handle_pushed(self, response): except Exception: log.exception("Pushed event handler errored, ignoring:") - def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message): + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=None): if self.is_defunct: raise ConnectionShutdown("Connection to %s is defunct" % self.host) elif self.is_closed: @@ -453,7 +454,7 @@ def send_msg(self, msg, 
request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages - self._requests[request_id] = (cb, decoder) + self._requests[request_id] = (cb, decoder, result_metadata) self.push(encoder(msg, request_id, self.protocol_version, compressor=self.compressor)) return request_id @@ -578,8 +579,9 @@ def process_msg(self, header, body): if stream_id < 0: callback = None decoder = ProtocolHandler.decode_message + result_metadata = None else: - callback, decoder = self._requests.pop(stream_id, None) + callback, decoder, result_metadata = self._requests.pop(stream_id) with self.lock: self.request_ids.append(stream_id) @@ -587,7 +589,7 @@ def process_msg(self, header, body): try: response = decoder(header.version, self.user_type_map, stream_id, - header.flags, header.opcode, body, self.decompressor) + header.flags, header.opcode, body, self.decompressor, result_metadata) except Exception as exc: log.exception("Error decoding response from Cassandra. " "%s; buffer: %r", header, self._iobuf.getvalue()) diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index 14b70915a7..0bb52d6bff 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -168,6 +168,36 @@ def __init__(self, self.position = Column.instance_counter Column.instance_counter += 1 + def __ne__(self, other): + if isinstance(other, Column): + return self.position != other.position + return NotImplemented + + def __eq__(self, other): + if isinstance(other, Column): + return self.position == other.position + return NotImplemented + + def __lt__(self, other): + if isinstance(other, Column): + return self.position < other.position + return NotImplemented + + def __le__(self, other): + if isinstance(other, Column): + return self.position <= other.position + return NotImplemented + + def __gt__(self, other): + if isinstance(other, Column): + return self.position > other.position + return NotImplemented + + def __ge__(self, other): + if isinstance(other, Column): + return self.position >= other.position + return NotImplemented + def validate(self, value): """ Returns a cleaned and validated value. Raises a ValidationError @@ -279,13 +309,6 @@ def to_database(self, value): Bytes = Blob -class Ascii(Column): - """ - Stores a US-ASCII character string - """ - db_type = 'ascii' - - class Inet(Column): """ Stores an IP address in IPv4 or IPv6 format @@ -305,25 +328,68 @@ def __init__(self, min_length=None, max_length=None, **kwargs): Defaults to 1 if this is a ``required`` column. Otherwise, None. :param int max_length: Sets the maximum length of this string, for validation purposes. 
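Example (a sketch; the ``Comment`` model is an assumption, not part of the driver):

.. code-block:: python

    class Comment(Model):
        id = columns.UUID(primary_key=True, default=uuid.uuid4)
        # rejects empty values and values longer than 140 characters
        body = columns.Text(min_length=1, max_length=140)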
""" - self.min_length = min_length or (1 if kwargs.get('required', False) else None) + self.min_length = ( + 1 if not min_length and kwargs.get('required', False) + else min_length) self.max_length = max_length + + if self.min_length is not None: + if self.min_length < 0: + raise ValueError( + 'Minimum length is not allowed to be negative.') + + if self.max_length is not None: + if self.max_length < 0: + raise ValueError( + 'Maximum length is not allowed to be negative.') + + if self.min_length is not None and self.max_length is not None: + if self.max_length < self.min_length: + raise ValueError( + 'Maximum length must be greater or equal ' + 'to minimum length.') + super(Text, self).__init__(**kwargs) def validate(self, value): value = super(Text, self).validate(value) - if value is None: - return if not isinstance(value, (six.string_types, bytearray)) and value is not None: raise ValidationError('{0} {1} is not a string'.format(self.column_name, type(value))) - if self.max_length: - if len(value) > self.max_length: + if self.max_length is not None: + if value and len(value) > self.max_length: raise ValidationError('{0} is longer than {1} characters'.format(self.column_name, self.max_length)) if self.min_length: - if len(value) < self.min_length: + if (self.min_length and not value) or len(value) < self.min_length: raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value +class Ascii(Text): + """ + Stores a US-ASCII character string + """ + db_type = 'ascii' + + def validate(self, value): + """ Only allow ASCII and None values. + + Check against US-ASCII, a.k.a. 7-bit ASCII, a.k.a. ISO646-US, a.k.a. + the Basic Latin block of the Unicode character set. + + Source: https://github.com/apache/cassandra/blob + /3dcbe90e02440e6ee534f643c7603d50ca08482b/src/java/org/apache/cassandra + /serializers/AsciiSerializer.java#L29 + """ + value = super(Ascii, self).validate(value) + if value: + charset = value if isinstance( + value, (bytearray, )) else map(ord, value) + if not set(range(128)).issuperset(charset): + raise ValidationError( + '{!r} is not an ASCII string.'.format(value)) + return value + + class Integer(Column): """ Stores a 32-bit signed integer value diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index 6978964ad0..cc2a34599f 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -21,7 +21,7 @@ from cassandra import metadata from cassandra.cqlengine import CQLEngineException -from cassandra.cqlengine import columns +from cassandra.cqlengine import columns, query from cassandra.cqlengine.connection import execute, get_cluster from cassandra.cqlengine.models import Model from cassandra.cqlengine.named import NamedTable @@ -119,10 +119,12 @@ def _get_index_name_by_column(table, column_name): return index_metadata.name -def sync_table(model): +def sync_table(model, keyspaces=None): """ Inspects the model and creates / updates the corresponding table and columns. + If `keyspaces` is specified, the table will be synched for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case. + Any User Defined Types used in the table are implicitly synchronized. This function can only add fields that are not part of the primary key. 
@@ -135,6 +137,20 @@ def sync_table(model):
     *There are plans to guard schema-modifying functions with an environment-driven conditional.*
     """
+
+    if keyspaces:
+        if not isinstance(keyspaces, (list, tuple)):
+            raise ValueError('keyspaces must be a list or a tuple.')
+
+        for keyspace in keyspaces:
+            with query.ContextQuery(model, keyspace=keyspace) as m:
+                _sync_table(m)
+    else:
+        _sync_table(model)
+
+
+def _sync_table(model):
+
     if not _allow_schema_modification():
         return

@@ -431,15 +447,29 @@ def _update_options(model):
     return False


-def drop_table(model):
+def drop_table(model, keyspaces=None):
     """
     Drops the table indicated by the model, if it exists.

+    If `keyspaces` is specified, the table will be dropped for all specified keyspaces. Note that the `Model.__keyspace__` is ignored in that case.
+
     **This function should be used with caution, especially in production environments.
     Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**

     *There are plans to guard schema-modifying functions with an environment-driven conditional.*
     """
+
+    if keyspaces:
+        if not isinstance(keyspaces, (list, tuple)):
+            raise ValueError('keyspaces must be a list or a tuple.')
+
+        for keyspace in keyspaces:
+            with query.ContextQuery(model, keyspace=keyspace) as m:
+                _drop_table(m)
+    else:
+        _drop_table(model)
+
+def _drop_table(model):
     if not _allow_schema_modification():
         return

diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py
index e940955ed4..41dfc77770 100644
--- a/cassandra/cqlengine/models.py
+++ b/cassandra/cqlengine/models.py
@@ -352,7 +352,7 @@ class MultipleObjectsReturned(_MultipleObjectsReturned):
     _table_name = None  # used internally to cache a derived table name

     def __init__(self, **values):
-        self._ttl = self.__default_ttl__
+        self._ttl = None
         self._timestamp = None
         self._conditional = None
         self._batch = None
@@ -361,7 +361,11 @@ def __init__(self, **values):

         self._values = {}
         for name, column in self._columns.items():
-            value = values.get(name)
+            # Set default values on instantiation. Thanks to this, we don't have
+            # to wait any longer for a call to validate() to have CQLengine set
+            # default column values.
+ column_default = column.get_default() if column.has_default else None + value = values.get(name, column_default) if value is not None or isinstance(column, columns.BaseContainerColumn): value = column.to_python(value) value_mngr = column.value_manager(self, column, value) @@ -691,7 +695,6 @@ def save(self): self._set_persisted() - self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -738,7 +741,6 @@ def update(self, **values): self._set_persisted() - self._ttl = self.__default_ttl__ self._timestamp = None return self @@ -794,17 +796,10 @@ def __new__(cls, name, bases, attrs): # short circuit __discriminator_value__ inheritance attrs['__discriminator_value__'] = attrs.get('__discriminator_value__') + # TODO __default__ttl__ should be removed in the next major release options = attrs.get('__options__') or {} attrs['__default_ttl__'] = options.get('default_time_to_live') - def _transform_column(col_name, col_obj): - column_dict[col_name] = col_obj - if col_obj.primary_key: - primary_keys[col_name] = col_obj - col_obj.set_column_name(col_name) - # set properties - attrs[col_name] = ColumnDescriptor(col_obj) - column_definitions = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, key=lambda x: x[1].position) @@ -849,6 +844,14 @@ def _get_polymorphic_base(bases): has_partition_keys = any(v.partition_key for (k, v) in column_definitions) + def _transform_column(col_name, col_obj): + column_dict[col_name] = col_obj + if col_obj.primary_key: + primary_keys[col_name] = col_obj + col_obj.set_column_name(col_name) + # set properties + attrs[col_name] = ColumnDescriptor(col_obj) + partition_key_index = 0 # transform column definitions for k, v in column_definitions: @@ -868,6 +871,12 @@ def _get_polymorphic_base(bases): if v.partition_key: v._partition_key_index = partition_key_index partition_key_index += 1 + + overriding = column_dict.get(k) + if overriding: + v.position = overriding.position + v.partition_key = overriding.partition_key + v._partition_key_index = overriding._partition_key_index _transform_column(k, v) partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) diff --git a/cassandra/cqlengine/query.py b/cassandra/cqlengine/query.py index 10d27ab580..e996baea3e 100644 --- a/cassandra/cqlengine/query.py +++ b/cassandra/cqlengine/query.py @@ -136,6 +136,8 @@ class BatchQuery(object): Handles the batching of queries http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH + + See :doc:`/cqlengine/batches` for more details. """ warn_multiple_exec = True @@ -259,6 +261,46 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.execute() +class ContextQuery(object): + """ + A Context manager to allow a Model to switch context easily. Presently, the context only + specifies a keyspace for model IO. + + For example: + + .. code-block:: python + + with ContextQuery(Automobile, keyspace='test2') as A: + A.objects.create(manufacturer='honda', year=2008, model='civic') + print len(A.objects.all()) # 1 result + + with ContextQuery(Automobile, keyspace='test4') as A: + print len(A.objects.all()) # 0 result + + """ + + def __init__(self, model, keyspace=None): + """ + :param model: A model. This should be a class type, not an instance. 
+ :param keyspace: (optional) A keyspace name + """ + from cassandra.cqlengine import models + + if not issubclass(model, models.Model): + raise CQLEngineException("Models must be derived from base Model.") + + ks = keyspace if keyspace else model.__keyspace__ + new_type = type(model.__name__, (model,), {'__keyspace__': ks}) + + self.model = new_type + + def __enter__(self): + return self.model + + def __exit__(self, exc_type, exc_val, exc_tb): + return + + class AbstractQuerySet(object): def __init__(self, model): @@ -299,7 +341,7 @@ def __init__(self, model): self._count = None self._batch = None - self._ttl = getattr(model, '__default_ttl__', None) + self._ttl = None self._consistency = None self._timestamp = None self._if_not_exists = False @@ -332,7 +374,7 @@ def __call__(self, *args, **kwargs): def __deepcopy__(self, memo): clone = self.__class__(self.model) for k, v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator']: # don't clone these + if k in ['_con', '_cur', '_result_cache', '_result_idx', '_result_generator', '_construct_result']: # don't clone these, which are per-request-execution clone.__dict__[k] = None elif k == '_batch': # we need to keep the same batch instance across @@ -545,7 +587,7 @@ def _parse_filter_arg(self, arg): if len(statement) == 1: return arg, None elif len(statement) == 2: - return statement[0], statement[1] + return (statement[0], statement[1]) if arg != 'pk__token' else (arg, None) else: raise QueryException("Can't parse '{0}'".format(arg)) @@ -954,7 +996,8 @@ class ModelQuerySet(AbstractQuerySet): def _validate_select_where(self): """ Checks that a filterset will not create invalid select statement """ # check that there's either a =, a IN or a CONTAINS (collection) relationship with a primary key or indexed field - equal_ops = [self.model._get_column_by_db_name(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] + equal_ops = [self.model._get_column_by_db_name(w.field) \ + for w in self._where if isinstance(w.operator, EqualsOperator) and not isinstance(w.value, Token)] token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) if not any(w.primary_key or w.index for w in equal_ops) and not token_comparison and not self._allow_filtering: raise QueryException(('Where clauses require either =, a IN or a CONTAINS (collection) ' @@ -971,6 +1014,9 @@ def _select_fields(self): fields = self.model._columns.keys() if self._defer_fields: fields = [f for f in fields if f not in self._defer_fields] + # select the partition keys if all model fields are set defer + if not fields: + fields = self.model._partition_keys if self._only_fields: fields = [f for f in fields if f in self._only_fields] if not fields: @@ -1154,6 +1200,7 @@ class Row(Model): return nulled_columns = set() + updated_columns = set() us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, val in values.items(): @@ -1174,13 +1221,16 @@ class Row(Model): continue us.add_update(col, val, operation=col_op) + updated_columns.add(col_name) if us.assignments: self._execute(us) if nulled_columns: + delete_conditional = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None ds = DeleteStatement(self.column_family_name, fields=nulled_columns, - where=self._where, conditionals=self._conditional, if_exists=self._if_exists) + 
where=self._where, conditionals=delete_conditional, if_exists=self._if_exists) self._execute(ds) @@ -1227,11 +1277,11 @@ def batch(self, batch_obj): self._batch = batch_obj return self - def _delete_null_columns(self): + def _delete_null_columns(self, conditionals=None): """ executes a delete query to remove columns that have changed to null """ - ds = DeleteStatement(self.column_family_name, conditionals=self._conditional, if_exists=self._if_exists) + ds = DeleteStatement(self.column_family_name, conditionals=conditionals, if_exists=self._if_exists) deleted_fields = False for _, v in self.instance._values.items(): col = v.column @@ -1265,6 +1315,8 @@ def update(self): conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.instance._clustering_keys.items(): null_clustering_key = null_clustering_key and col._val_is_null(getattr(self.instance, name, None)) + + updated_columns = set() # get defined fields and their column names for name, col in self.model._columns.items(): # if clustering key is null, don't include non static columns @@ -1282,6 +1334,7 @@ def update(self): static_changed_only = static_changed_only and col.static statement.add_update(col, val, previous=val_mgr.previous_value) + updated_columns.add(col.db_field_name) if statement.assignments: for name, col in self.model._primary_keys.items(): @@ -1292,7 +1345,10 @@ def update(self): self._execute(statement) if not null_clustering_key: - self._delete_null_columns() + # remove conditions on fields that have been updated + delete_conditionals = [condition for condition in self._conditional + if condition.field not in updated_columns] if self._conditional else None + self._delete_null_columns(delete_conditionals) def save(self): """ @@ -1341,7 +1397,7 @@ def delete(self): ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp, conditionals=self._conditional, if_exists=self._if_exists) for name, col in self.model._primary_keys.items(): val = getattr(self.instance, name) - if val is None and not col.parition_key: + if val is None and not col.partition_key: continue ds.add_where(col, EqualsOperator(), val) self._execute(ds) diff --git a/cassandra/cqlengine/statements.py b/cassandra/cqlengine/statements.py index 3867704a77..44ae165e8b 100644 --- a/cassandra/cqlengine/statements.py +++ b/cassandra/cqlengine/statements.py @@ -35,9 +35,7 @@ def __init__(self, value): def __unicode__(self): from cassandra.encoder import cql_quote - if isinstance(self.value, bool): - return 'true' if self.value else 'false' - elif isinstance(self.value, (list, tuple)): + if isinstance(self.value, (list, tuple)): return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' elif isinstance(self.value, dict): return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k, v in self.value.items()]) + '}' diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 7eb0a2df58..b6a720e6c9 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -107,7 +107,7 @@ def __new__(metacls, name, bases, dct): cls = type.__new__(metacls, name, bases, dct) if not name.startswith('_'): _casstypes[name] = cls - if not cls.typename.startswith("'org"): + if not cls.typename.startswith(apache_cassandra_type_prefix): _cqltypes[cls.typename] = cls return cls @@ -682,6 +682,8 @@ class VarcharType(UTF8Type): class _ParameterizedType(_CassandraType): + num_subtypes = 'UNKNOWN' + @classmethod def deserialize(cls, byts, protocol_version): if not cls.subtypes: @@ -802,7 +804,6 @@ def serialize_safe(cls, themap, 
protocol_version): class TupleType(_ParameterizedType): typename = 'tuple' - num_subtypes = 'UNKNOWN' @classmethod def deserialize_safe(cls, byts, protocol_version): @@ -853,7 +854,7 @@ def cql_parameterized_type(cls): class UserType(TupleType): - typename = "'org.apache.cassandra.db.marshal.UserType'" + typename = "org.apache.cassandra.db.marshal.UserType" _cache = {} _module = sys.modules[__name__] @@ -956,8 +957,7 @@ def _make_udt_tuple_type(cls, name, field_names): class CompositeType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.CompositeType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.CompositeType" @classmethod def cql_parameterized_type(cls): @@ -985,8 +985,13 @@ def deserialize_safe(cls, byts, protocol_version): return tuple(result) -class DynamicCompositeType(CompositeType): - typename = "'org.apache.cassandra.db.marshal.DynamicCompositeType'" +class DynamicCompositeType(_ParameterizedType): + typename = "org.apache.cassandra.db.marshal.DynamicCompositeType" + + @classmethod + def cql_parameterized_type(cls): + sublist = ', '.join('%s=>%s' % (alias, typ.cass_parameterized_type(full=True)) for alias, typ in zip(cls.fieldnames, cls.subtypes)) + return "'%s(%s)'" % (cls.typename, sublist) class ColumnToCollectionType(_ParameterizedType): @@ -995,12 +1000,11 @@ class ColumnToCollectionType(_ParameterizedType): Cassandra includes this. We don't actually need or want the extra information. """ - typename = "'org.apache.cassandra.db.marshal.ColumnToCollectionType'" - num_subtypes = 'UNKNOWN' + typename = "org.apache.cassandra.db.marshal.ColumnToCollectionType" class ReversedType(_ParameterizedType): - typename = "'org.apache.cassandra.db.marshal.ReversedType'" + typename = "org.apache.cassandra.db.marshal.ReversedType" num_subtypes = 1 @classmethod diff --git a/cassandra/encoder.py b/cassandra/encoder.py index 6d8b6ce8a2..98d562d1bc 100644 --- a/cassandra/encoder.py +++ b/cassandra/encoder.py @@ -40,8 +40,7 @@ def cql_quote(term): # The ordering of this method is important for the result of this method to # be a native str type (for both Python 2 and 3) - # Handle quoting of native str and bool types - if isinstance(term, (str, bool)): + if isinstance(term, str): return "'%s'" % str(term).replace("'", "''") # This branch of the if statement will only be used by Python 2 to catch # unicode strings, text_type is used to prevent type errors with Python 3. diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index dfaea8bfb4..cf1616d45b 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -16,13 +16,10 @@ # Originally derived from MagnetoDB source: # https://github.com/stackforge/magnetodb/blob/2015.1.0b1/magnetodb/common/cassandra/io/eventletreactor.py -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, EINVAL import eventlet from eventlet.green import socket -import ssl from eventlet.queue import Queue import logging -import os from threading import Event import time @@ -34,15 +31,6 @@ log = logging.getLogger(__name__) -def is_timeout(err): - return ( - err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or - (err == EINVAL and os.name in ('nt', 'ce')) or - (isinstance(err, ssl.SSLError) and err.args[0] == 'timed out') or - isinstance(err, socket.timeout) - ) - - class EventletConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``eventlet``. 
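A minimal sketch of selecting this reactor (assuming the usual eventlet monkey-patching convention; not part of the patch):

.. code-block:: python

    import eventlet
    eventlet.monkey_patch()

    from cassandra.cluster import Cluster
    from cassandra.io.eventletreactor import EventletConnection

    cluster = Cluster(connection_class=EventletConnection)
    session = cluster.connect()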
@@ -145,8 +133,6 @@ def handle_read(self): buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: - if is_timeout(err): - continue log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index 65572a664c..bf0a4cc181 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -18,26 +18,16 @@ import gevent.ssl import logging -import os import time from six.moves import range -from errno import EINVAL - from cassandra.connection import Connection, ConnectionShutdown, Timer, TimerManager log = logging.getLogger(__name__) -def is_timeout(err): - return ( - (err == EINVAL and os.name in ('nt', 'ce')) or - isinstance(err, socket.timeout) - ) - - class GeventConnection(Connection): """ An implementation of :class:`.Connection` that utilizes ``gevent``. @@ -131,11 +121,9 @@ def handle_read(self): buf = self._socket.recv(self.in_buffer_size) self._iobuf.write(buf) except socket.error as err: - if not is_timeout(err): - log.debug("Exception in read for %s: %s", self, err) - self.defunct(err) - return # leave the read loop - continue + log.debug("Exception in read for %s: %s", self, err) + self.defunct(err) + return # leave the read loop if self._iobuf.tell(): self.process_io_buffer() diff --git a/cassandra/io/libevreactor.py b/cassandra/io/libevreactor.py index a3e96a9a03..39f871a135 100644 --- a/cassandra/io/libevreactor.py +++ b/cassandra/io/libevreactor.py @@ -102,10 +102,10 @@ def maybe_start(self): def _run_loop(self): while True: - end_condition = self._loop.start() + self._loop.start() # there are still active watchers, no deadlock with self._lock: - if not self._shutdown and (end_condition or self._live_conns): + if not self._shutdown and self._live_conns: log.debug("Restarting event loop") continue else: @@ -121,10 +121,7 @@ def _cleanup(self): for conn in self._live_conns | self._new_conns | self._closed_conns: conn.close() - if conn._write_watcher: - conn._write_watcher.stop() - if conn._read_watcher: - conn._read_watcher.stop() + map(lambda w: w.stop(), (w for w in (conn._write_watcher, conn._read_watcher) if w)) self.notify() # wake the timer watcher log.debug("Waiting for event loop thread to join...") @@ -135,7 +132,6 @@ def _cleanup(self): "Please call Cluster.shutdown() to avoid this.") log.debug("Event loop thread was joined") - self._loop = None def add_timer(self, timer): self._timers.add_timer(timer) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 1cd801eed2..dedaa2de7b 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1058,14 +1058,11 @@ def is_cql_compatible(self): """ comparator = getattr(self, 'comparator', None) if comparator: - # no such thing as DCT in CQL - incompatible = issubclass(self.comparator, types.DynamicCompositeType) - # no compact storage with more than one column beyond PK if there # are clustering columns - incompatible |= (self.is_compact_storage and - len(self.columns) > len(self.primary_key) + 1 and - len(self.clustering_key) >= 1) + incompatible = (self.is_compact_storage and + len(self.columns) > len(self.primary_key) + 1 and + len(self.clustering_key) >= 1) return not incompatible return True @@ -1777,12 +1774,9 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None): comparator = types.lookup_casstype(row["comparator"]) table_meta.comparator = comparator - if issubclass(comparator, types.CompositeType): - column_name_types = 
comparator.subtypes
-            is_composite_comparator = True
-        else:
-            column_name_types = (comparator,)
-            is_composite_comparator = False
+        is_dct_comparator = issubclass(comparator, types.DynamicCompositeType)
+        is_composite_comparator = issubclass(comparator, types.CompositeType)
+        column_name_types = comparator.subtypes if is_composite_comparator else (comparator,)
         num_column_name_components = len(column_name_types)
         last_col = column_name_types[-1]

@@ -1796,7 +1790,8 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
             if column_aliases is not None:
                 column_aliases = json.loads(column_aliases)
-            else:
+
+            if not column_aliases:  # json load failed or column_aliases empty PYTHON-562
                 column_aliases = [r.get('column_name') for r in clustering_rows]

             if is_composite_comparator:
@@ -1819,10 +1814,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):

                     # Some thrift tables define names in composite types (see PYTHON-192)
                     if not column_aliases and hasattr(comparator, 'fieldnames'):
-                        column_aliases = comparator.fieldnames
+                        column_aliases = filter(None, comparator.fieldnames)
             else:
                 is_compact = True
-                if column_aliases or not col_rows:
+                if column_aliases or not col_rows or is_dct_comparator:
                     has_value = True
                     clustering_size = num_column_name_components
                 else:
@@ -1867,7 +1862,7 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None):
                 if len(column_aliases) > i:
                     column_name = column_aliases[i]
                 else:
-                    column_name = "column%d" % i
+                    column_name = "column%d" % (i + 1)

                 data_type = column_name_types[i]
                 cql_type = _cql_from_cass_type(data_type)
diff --git a/cassandra/metrics.py b/cassandra/metrics.py
index cf1f25c15d..d0c5b9e39c 100644
--- a/cassandra/metrics.py
+++ b/cassandra/metrics.py
@@ -111,10 +111,14 @@ class Metrics(object):
     the driver currently has open.
     """

+    _stats_counter = 0
+
     def __init__(self, cluster_proxy):
         log.debug("Starting metric capture")
-        self.stats = scales.collection('/cassandra',
+        self.stats_name = 'cassandra-{0}'.format(str(self._stats_counter))
+        Metrics._stats_counter += 1
+        self.stats = scales.collection(self.stats_name,
                                        scales.PmfStat('request_timer'),
                                        scales.IntStat('connection_errors'),
                                        scales.IntStat('write_timeouts'),
@@ -132,6 +136,11 @@ def __init__(self, cluster_proxy):
                                        scales.Stat('open_connections',
                                                    lambda: sum(sum(p.open_count for p in s._pools.values()) for s in cluster_proxy.sessions)))

+        # TODO, to be removed in 4.0
+        # /cassandra contains the metrics of the first cluster registered
+        if 'cassandra' not in scales._Stats.stats:
+            scales._Stats.stats['cassandra'] = scales._Stats.stats[self.stats_name]
+
         self.request_timer = self.stats.request_timer
         self.connection_errors = self.stats.connection_errors
         self.write_timeouts = self.stats.write_timeouts
@@ -164,3 +173,27 @@ def on_ignore(self):

     def on_retry(self):
         self.stats.retries += 1
+
+    def get_stats(self):
+        """
+        Returns the metrics for the registered cluster instance.
+        """
+        return scales.getStats()[self.stats_name]
+
+    def set_stats_name(self, stats_name):
+        """
+        Set the metrics stats name.
+        The stats_name is a string used to access the metrics through scales: scales.getStats()[<stats_name>]
+        Default is 'cassandra-<num>'.
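Example (a sketch; the stats name is an arbitrary assumption):

.. code-block:: python

    cluster = Cluster(metrics_enabled=True)
    session = cluster.connect()
    cluster.metrics.set_stats_name('cluster1-metrics')
    stats = cluster.metrics.get_stats()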
+ """ + + if self.stats_name == stats_name: + return + + if stats_name in scales._Stats.stats: + raise ValueError('"{0}" already exists in stats.'.format(stats_name)) + + stats = scales._Stats.stats[self.stats_name] + del scales._Stats.stats[self.stats_name] + self.stats_name = stats_name + scales._Stats.stats[self.stats_name] = stats diff --git a/cassandra/numpy_parser.pyx b/cassandra/numpy_parser.pyx index 1334e747c4..ed755d00a4 100644 --- a/cassandra/numpy_parser.pyx +++ b/cassandra/numpy_parser.pyx @@ -13,7 +13,7 @@ # limitations under the License. """ -This module provider an optional protocol parser that returns +This module provides an optional protocol parser that returns NumPy arrays. ============================================================================= @@ -25,7 +25,7 @@ as numpy is an optional dependency. include "ioutils.pyx" cimport cython -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, uint8_t from cpython.ref cimport Py_INCREF, PyObject from cassandra.bytesio cimport BytesIOReader @@ -35,7 +35,6 @@ from cassandra import cqltypes from cassandra.util import is_little_endian import numpy as np -# import pandas as pd cdef extern from "numpyFlags.h": # Include 'numpyFlags.h' into the generated C code to disable the @@ -52,11 +51,13 @@ ctypedef struct ArrDesc: Py_uintptr_t buf_ptr int stride # should be large enough as we allocate contiguous arrays int is_object + Py_uintptr_t mask_ptr arrDescDtype = np.dtype( [ ('buf_ptr', np.uintp) , ('stride', np.dtype('i')) , ('is_object', np.dtype('i')) + , ('mask_ptr', np.uintp) ], align=True) _cqltype_to_numpy = { @@ -70,6 +71,7 @@ _cqltype_to_numpy = { obj_dtype = np.dtype('O') +cdef uint8_t mask_true = 0x01 cdef class NumpyParser(ColumnParser): """Decode a ResultMessage into a bunch of NumPy arrays""" @@ -116,7 +118,11 @@ def make_arrays(ParseDesc desc, array_size): arr = make_array(coltype, array_size) array_descs[i]['buf_ptr'] = arr.ctypes.data array_descs[i]['stride'] = arr.strides[0] - array_descs[i]['is_object'] = coltype not in _cqltype_to_numpy + array_descs[i]['is_object'] = arr.dtype is obj_dtype + try: + array_descs[i]['mask_ptr'] = arr.mask.ctypes.data + except AttributeError: + array_descs[i]['mask_ptr'] = 0 arrays.append(arr) return array_descs, arrays @@ -126,8 +132,12 @@ def make_array(coltype, array_size): """ Allocate a new NumPy array of the given column type and size. 
""" - dtype = _cqltype_to_numpy.get(coltype, obj_dtype) - return np.empty((array_size,), dtype=dtype) + try: + a = np.ma.empty((array_size,), dtype=_cqltype_to_numpy[coltype]) + a.mask = np.zeros((array_size,), dtype=np.bool) + except KeyError: + a = np.empty((array_size,), dtype=obj_dtype) + return a #### Parse rows into NumPy arrays @@ -140,7 +150,6 @@ cdef inline int unpack_row( cdef Py_ssize_t i, rowsize = desc.rowsize cdef ArrDesc arr cdef Deserializer deserializer - for i in range(rowsize): get_buf(reader, &buf) arr = arrays[i] @@ -150,13 +159,14 @@ cdef inline int unpack_row( val = from_binary(deserializer, &buf, desc.protocol_version) Py_INCREF(val) ( arr.buf_ptr)[0] = val - elif buf.size < 0: - raise ValueError("Cannot handle NULL value") - else: + elif buf.size >= 0: memcpy( arr.buf_ptr, buf.ptr, buf.size) + else: + memcpy(arr.mask_ptr, &mask_true, 1) # Update the pointer into the array for the next time arrays[i].buf_ptr += arr.stride + arrays[i].mask_ptr += 1 return 0 diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 4c63d557d5..e9e4450f5a 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -126,7 +126,7 @@ def __init__(self, code, message, info): self.info = info @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): code = read_int(f) msg = read_string(f) subcls = error_classes.get(code, cls) @@ -378,7 +378,7 @@ class ReadyMessage(_MessageType): name = 'READY' @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, *args): return cls() @@ -390,7 +390,7 @@ def __init__(self, authenticator): self.authenticator = authenticator @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): authname = read_string(f) return cls(authenticator=authname) @@ -422,7 +422,7 @@ def __init__(self, challenge): self.challenge = challenge @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): return cls(read_binary_longstring(f)) @@ -445,7 +445,7 @@ def __init__(self, token): self.token = token @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): return cls(read_longstring(f)) @@ -466,7 +466,7 @@ def __init__(self, cql_versions, options): self.options = options @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, *args): options = read_stringmultimap(f) cql_versions = options.pop('CQL_VERSION') return cls(cql_versions=cql_versions, options=options) @@ -474,7 +474,7 @@ def recv_body(cls, f, protocol_version, user_type_map): # used for QueryMessage and ExecuteMessage _VALUES_FLAG = 0x01 -_SKIP_METADATA_FLAG = 0x01 +_SKIP_METADATA_FLAG = 0x02 _PAGE_SIZE_FLAG = 0x04 _WITH_PAGING_STATE_FLAG = 0x08 _WITH_SERIAL_CONSISTENCY_FLAG = 0x10 @@ -577,14 +577,14 @@ def __init__(self, kind, results, paging_state=None): self.paging_state = paging_state @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, protocol_version, user_type_map, result_metadata): kind = read_int(f) paging_state = None if kind == RESULT_KIND_VOID: results = None elif kind == RESULT_KIND_ROWS: paging_state, results = cls.recv_results_rows( - f, protocol_version, user_type_map) + f, protocol_version, user_type_map, result_metadata) elif kind == RESULT_KIND_SET_KEYSPACE: ksname = read_string(f) results = ksname @@ -597,8 +597,9 @@ def recv_body(cls, f, protocol_version, user_type_map): return cls(kind, 
results, paging_state) @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): + def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) + column_metadata = column_metadata or result_metadata rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] colnames = [c[2] for c in column_metadata] @@ -607,24 +608,29 @@ def recv_results_rows(cls, f, protocol_version, user_type_map): tuple(ctype.from_binary(val, protocol_version) for ctype, val in zip(coltypes, row)) for row in rows] - return (paging_state, (colnames, parsed_rows)) + return paging_state, (colnames, parsed_rows) @classmethod def recv_results_prepared(cls, f, protocol_version, user_type_map): query_id = read_binary_string(f) - column_metadata, pk_indexes = cls.recv_prepared_metadata(f, protocol_version, user_type_map) - return (query_id, column_metadata, pk_indexes) + bind_metadata, pk_indexes, result_metadata = cls.recv_prepared_metadata(f, protocol_version, user_type_map) + return query_id, bind_metadata, pk_indexes, result_metadata @classmethod def recv_results_metadata(cls, f, user_type_map): flags = read_int(f) - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) if flags & cls._HAS_MORE_PAGES_FLAG: paging_state = read_binary_longstring(f) else: paging_state = None + + no_meta = bool(flags & cls._NO_METADATA_FLAG) + if no_meta: + return paging_state, [] + + glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) @@ -644,17 +650,17 @@ def recv_results_metadata(cls, f, user_type_map): @classmethod def recv_prepared_metadata(cls, f, protocol_version, user_type_map): flags = read_int(f) - glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) colcount = read_int(f) pk_indexes = None if protocol_version >= 4: num_pk_indexes = read_int(f) pk_indexes = [read_short(f) for _ in range(num_pk_indexes)] + glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC) if glob_tblspec: ksname = read_string(f) cfname = read_string(f) - column_metadata = [] + bind_metadata = [] for _ in range(colcount): if glob_tblspec: colksname = ksname @@ -664,8 +670,13 @@ def recv_prepared_metadata(cls, f, protocol_version, user_type_map): colcfname = read_string(f) colname = read_string(f) coltype = cls.read_type(f, user_type_map) - column_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) - return column_metadata, pk_indexes + bind_metadata.append(ColumnMetadata(colksname, colcfname, colname, coltype)) + + if protocol_version >= 2: + _, result_metadata = cls.recv_results_metadata(f, user_type_map) + return bind_metadata, pk_indexes, result_metadata + else: + return bind_metadata, pk_indexes, None @classmethod def recv_results_schema_change(cls, f, protocol_version): @@ -727,7 +738,7 @@ class ExecuteMessage(_MessageType): def __init__(self, query_id, query_params, consistency_level, serial_consistency_level=None, fetch_size=None, - paging_state=None, timestamp=None): + paging_state=None, timestamp=None, skip_meta=False): self.query_id = query_id self.query_params = query_params self.consistency_level = consistency_level @@ -735,6 +746,7 @@ def __init__(self, query_id, query_params, consistency_level, self.fetch_size = fetch_size self.paging_state = paging_state self.timestamp = timestamp + self.skip_meta = skip_meta def send_body(self, f, protocol_version): write_string(f, 
self.query_id) @@ -768,6 +780,8 @@ def send_body(self, f, protocol_version): raise UnsupportedOperation( "Protocol-level timestamps may only be used with protocol version " "3 or higher. Consider setting Cluster.protocol_version to 3.") + if self.skip_meta: + flags |= _SKIP_METADATA_FLAG write_byte(f, flags) write_short(f, len(self.query_params)) for param in self.query_params: @@ -782,6 +796,7 @@ def send_body(self, f, protocol_version): write_long(f, self.timestamp) + class BatchMessage(_MessageType): opcode = 0x0D name = 'BATCH' @@ -851,7 +866,7 @@ def __init__(self, event_type, event_args): self.event_args = event_args @classmethod - def recv_body(cls, f, protocol_version, user_type_map): + def recv_body(cls, f, protocol_version, *args): event_type = read_string(f).upper() if event_type in known_event_types: read_method = getattr(cls, 'recv_' + event_type.lower()) @@ -960,7 +975,7 @@ def _write_header(f, version, flags, stream_id, opcode, length): @classmethod def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcode, body, - decompressor): + decompressor, result_metadata): """ Decodes a native protocol message body @@ -1002,7 +1017,7 @@ def decode_message(cls, protocol_version, user_type_map, stream_id, flags, opcod log.warning("Unknown protocol flags set: %02x. May cause problems.", flags) msg_class = cls.message_types_by_opcode[opcode] - msg = msg_class.recv_body(body, protocol_version, user_type_map) + msg = msg_class.recv_body(body, protocol_version, user_type_map, result_metadata) msg.stream_id = stream_id msg.trace_id = trace_id msg.custom_payload = custom_payload diff --git a/cassandra/query.py b/cassandra/query.py index 8662f0bda4..65cb6ba9e0 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -219,8 +219,7 @@ class Statement(object): _routing_key = None def __init__(self, retry_policy=None, consistency_level=None, routing_key=None, - serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, - custom_payload=None): + serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None): if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy') self.retry_policy = retry_policy @@ -362,36 +361,34 @@ class PreparedStatement(object): may affect performance (as the operation requires a network roundtrip). 
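Example of typical use through ``Session.prepare`` (the table and key value are assumptions):

.. code-block:: python

    prepared = session.prepare("SELECT * FROM users WHERE user_id=?")
    bound = prepared.bind((some_user_id,))  # some_user_id: an existing key
    rows = session.execute(bound)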
""" - column_metadata = None + column_metadata = None #TODO: make this bind_metadata in next major + consistency_level = None + custom_payload = None + fetch_size = FETCH_SIZE_UNSET + keyspace = None # change to prepared_keyspace in major release + protocol_version = None query_id = None query_string = None - keyspace = None # change to prepared_keyspace in major release - + result_metadata = None routing_key_indexes = None _routing_key_index_set = None - - consistency_level = None serial_consistency_level = None - protocol_version = None - - fetch_size = FETCH_SIZE_UNSET - - custom_payload = None - def __init__(self, column_metadata, query_id, routing_key_indexes, query, - keyspace, protocol_version): + keyspace, protocol_version, result_metadata): self.column_metadata = column_metadata self.query_id = query_id self.routing_key_indexes = routing_key_indexes self.query_string = query self.keyspace = keyspace self.protocol_version = protocol_version + self.result_metadata = result_metadata @classmethod - def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version): + def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, + query, prepared_keyspace, protocol_version, result_metadata): if not column_metadata: - return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version) + return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version, result_metadata) if pk_indexes: routing_key_indexes = pk_indexes @@ -416,7 +413,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, q pass # statement; just leave routing_key_indexes as None return PreparedStatement(column_metadata, query_id, routing_key_indexes, - query, prepared_keyspace, protocol_version) + query, prepared_keyspace, protocol_version, result_metadata) def bind(self, values): """ diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index ec2b83bed7..8422d544d3 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -18,13 +18,15 @@ from cassandra.deserializers import make_deserializers include "ioutils.pyx" def make_recv_results_rows(ColumnParser colparser): - def recv_results_rows(cls, f, int protocol_version, user_type_map): + def recv_results_rows(cls, f, int protocol_version, user_type_map, result_metadata): """ Parse protocol data given as a BytesIO f into a set of columns (e.g. 
list of tuples)

        This is used as the recv_results_rows method of (Fast)ResultMessage
        """
        paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map)
+        column_metadata = column_metadata or result_metadata
+
        colnames = [c[2] for c in column_metadata]
        coltypes = [c[3] for c in column_metadata]
diff --git a/cassandra/type_codes.py b/cassandra/type_codes.py
index 2f0ce8f5a0..daf882e46c 100644
--- a/cassandra/type_codes.py
+++ b/cassandra/type_codes.py
@@ -59,4 +59,3 @@
 SetType = 0x0022
 UserType = 0x0030
 TupleType = 0x0031
-
diff --git a/cassandra/util.py b/cassandra/util.py
index f4bc1b1c94..7f17e85d18 100644
--- a/cassandra/util.py
+++ b/cassandra/util.py
@@ -973,6 +973,8 @@ def __eq__(self, other):
                     microsecond=self.nanosecond // Time.MICRO) == other

     def __lt__(self, other):
+        if not isinstance(other, Time):
+            return NotImplemented
         return self.nanosecond_time < other.nanosecond_time

     def __repr__(self):
@@ -1061,6 +1063,8 @@ def __eq__(self, other):
             return False

     def __lt__(self, other):
+        if not isinstance(other, Date):
+            return NotImplemented
         return self.days_from_epoch < other.days_from_epoch

     def __repr__(self):
diff --git a/docs.yaml b/docs.yaml
index aa30ed5df3..b337d5dd7b 100644
--- a/docs.yaml
+++ b/docs.yaml
@@ -6,3 +6,8 @@ sections:
     prefix: /
     type: sphinx
     directory: docs
+versions:
+  - name: 3.5.0
+    ref: 3.5.0
+redirects:
+  - \A\/(.*)/\Z: /\1.html
diff --git a/docs/api/cassandra/cqlengine/models.rst b/docs/api/cassandra/cqlengine/models.rst
index d6f3391974..fd081fb190 100644
--- a/docs/api/cassandra/cqlengine/models.rst
+++ b/docs/api/cassandra/cqlengine/models.rst
@@ -32,8 +32,10 @@ Model

    .. autoattribute:: __keyspace__

-   .. _ttl-change:
-   .. autoattribute:: __default_ttl__
+   .. attribute:: __default_ttl__
+      :annotation: = None
+
+      Will be deprecated in release 4.0. You can set the default TTL by configuring the table ``__options__``. See :ref:`ttl-change` for more details.

    .. autoattribute:: __discriminator_value__

diff --git a/docs/api/cassandra/cqlengine/query.rst b/docs/api/cassandra/cqlengine/query.rst
index 461ec9b969..c0c8f285cf 100644
--- a/docs/api/cassandra/cqlengine/query.rst
+++ b/docs/api/cassandra/cqlengine/query.rst
@@ -54,6 +54,14 @@ The methods here are used to filter, order, and constrain results.

    .. automethod:: update

+.. autoclass:: BatchQuery
+   :members:
+
+   .. automethod:: add_query
+   .. automethod:: execute
+
+.. autoclass:: ContextQuery
+
 .. autoclass:: DoesNotExist

 .. autoclass:: MultipleObjectsReturned
diff --git a/docs/cqlengine/queryset.rst b/docs/cqlengine/queryset.rst
index ff328b0ce4..c9c33932f8 100644
--- a/docs/cqlengine/queryset.rst
+++ b/docs/cqlengine/queryset.rst
@@ -343,6 +343,42 @@ None means no timeout.

 Setting the timeout on the model is meaningless and will raise an AssertionError.

+.. _ttl-change:
+
+Default TTL and Per Query TTL
+=============================
+
+Model default TTL now relies on the *default_time_to_live* feature, introduced in Cassandra 2.0. It is no longer handled by the CQLEngine Model (cassandra-driver >=3.6). You can set the default TTL of a table like this:
+
+    Example:
+
+    .. code-block:: python
+
+        class User(Model):
+            __options__ = {'default_time_to_live': 20}
+
+            user_id = columns.UUID(primary_key=True)
+            ...
+
+You can set TTL per-query if needed. Here are some examples:
+
+    Example:
+
+    .. code-block:: python
+
+        class User(Model):
+            __options__ = {'default_time_to_live': 20}
+
+            user_id = columns.UUID(primary_key=True)
+            ...
+
+        user = User.objects.create(user_id=1)  # Default TTL 20 will be set automatically on the server
+
+        user.ttl(30).update(age=21)  # Update the TTL to 30
+        User.objects.ttl(10).create(user_id=1)  # TTL 10
+        User(user_id=1, age=21).ttl(10).save()  # TTL 10
+
+
 Named Tables
 ===================
diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index 2d9c7ea461..c7cbc25970 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -179,7 +179,7 @@ Named place-holders use the ``%(name)s`` form:
     """
     INSERT INTO users (name, credits, user_id, username)
     VALUES (%(name)s, %(credits)s, %(user_id)s, %(name)s)
-    """
+    """,
     {'name': "John O'Reilly", 'credits': 42, 'user_id': uuid.uuid1()}
 )
diff --git a/test-requirements.txt b/test-requirements.txt
index 4c917da6c6..500795357c 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,7 +1,7 @@
 -r requirements.txt
 scales
 nose
-mock<=1.0.1
+mock!=1.1.*
 ccm>=2.0
 unittest2
 PyYAML
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
index 62a58896a4..bd9fe103cd 100644
--- a/tests/integration/__init__.py
+++ b/tests/integration/__init__.py
@@ -23,6 +23,7 @@
 import sys
 import time
 import traceback
+import platform
 from threading import Event
 from subprocess import call
 from itertools import groupby
@@ -137,14 +138,67 @@ def _get_cass_version_from_dse(dse_version):
         CCM_KWARGS['dse_credentials_file'] = DSE_CRED

-if CASSANDRA_VERSION >= '2.2':
-    default_protocol_version = 4
-elif CASSANDRA_VERSION >= '2.1':
-    default_protocol_version = 3
-elif CASSANDRA_VERSION >= '2.0':
-    default_protocol_version = 2
-else:
-    default_protocol_version = 1
+def get_default_protocol():
+
+    if CASSANDRA_VERSION >= '2.2':
+        return 4
+    elif CASSANDRA_VERSION >= '2.1':
+        return 3
+    elif CASSANDRA_VERSION >= '2.0':
+        return 2
+    else:
+        return 1
+
+
+def get_supported_protocol_versions():
+    """
+    1.2 -> 1
+    2.0 -> 2, 1
+    2.1 -> 3, 2, 1
+    2.2 -> 4, 3, 2, 1
+    3.X -> 4, 3
+    """
+    if CASSANDRA_VERSION >= '3.0':
+        return (3, 4)
+    elif CASSANDRA_VERSION >= '2.2':
+        return (1, 2, 3, 4)
+    elif CASSANDRA_VERSION >= '2.1':
+        return (1, 2, 3)
+    elif CASSANDRA_VERSION >= '2.0':
+        return (1, 2)
+    else:
+        return (1,)
+
+
+def get_unsupported_lower_protocol():
+    """
+    This is used to determine the lowest protocol version that is NOT
+    supported by the version of C* running
+    """
+
+    if CASSANDRA_VERSION >= '3.0':
+        return 2
+    else:
+        return None
+
+
+def get_unsupported_upper_protocol():
+    """
+    This is used to determine the highest protocol version that is NOT
+    supported by the version of C* running
+    """
+
+    if CASSANDRA_VERSION >= '2.2':
+        return None
+    if CASSANDRA_VERSION >= '2.1':
+        return 4
+    elif CASSANDRA_VERSION >= '2.0':
+        return 3
+    else:
+        return None
+
+default_protocol_version = get_default_protocol()
+

 PROTOCOL_VERSION = int(os.getenv('PROTOCOL_VERSION', default_protocol_version))

@@ -157,6 +211,7 @@ def _get_cass_version_from_dse(dse_version):
 greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= '3.0', 'Cassandra version 3.0 or greater required')
 lessthancass30 = unittest.skipUnless(CASSANDRA_VERSION < '3.0', 'Cassandra version less than 3.0 required')
 dseonly = unittest.skipUnless(DSE_VERSION, "Test is only applicable to DSE clusters")
+pypy = unittest.skipUnless(platform.python_implementation() == "PyPy", "Test is skipped unless it's on PyPy")

 def wait_for_node_socket(node, timeout):
@@ -241,6 +296,7 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=[]):
             log.debug("Using external 
CCM cluster {0}".format(CCM_CLUSTER.name)) else: log.debug("Using unnamed external cluster") + setup_keyspace(ipformat=ipformat, wait=False) return if is_current_cluster(cluster_name, nodes): @@ -387,9 +443,10 @@ def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): cluster.shutdown() -def setup_keyspace(ipformat=None): +def setup_keyspace(ipformat=None, wait=True): # wait for nodes to startup - time.sleep(10) + if wait: + time.sleep(10) if not ipformat: cluster = Cluster(protocol_version=PROTOCOL_VERSION) @@ -481,8 +538,8 @@ def create_keyspace(cls, rf): execute_with_long_wait_retry(cls.session, ddl) @classmethod - def common_setup(cls, rf, keyspace_creation=True, create_class_table=False): - cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION) + def common_setup(cls, rf, keyspace_creation=True, create_class_table=False, metrics=False): + cls.cluster = Cluster(protocol_version=PROTOCOL_VERSION, metrics_enabled=metrics) cls.session = cls.cluster.connect() cls.ks_name = cls.__name__.lower() if keyspace_creation: @@ -535,6 +592,7 @@ def get_message_count(self, level, sub_string): count+=1 return count + class BasicExistingKeyspaceUnitTestCase(BasicKeyspaceUnitTestCase): """ This is basic unit test defines class level teardown and setup methods. It assumes that keyspace is already defined, or created as part of the test. @@ -589,7 +647,7 @@ class BasicSharedKeyspaceUnitTestCaseWTable(BasicSharedKeyspaceUnitTestCase): """ @classmethod def setUpClass(self): - self.common_setup(2, True) + self.common_setup(3, True, True, True) class BasicSharedKeyspaceUnitTestCaseRF3(BasicSharedKeyspaceUnitTestCase): diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py index e61698e82c..3f163ded64 100644 --- a/tests/integration/cqlengine/__init__.py +++ b/tests/integration/cqlengine/__init__.py @@ -96,7 +96,7 @@ def wrapped_function(*args, **kwargs): else: test_case = args[0] # Check to see if the count is what you expect - test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls doesn't match actual number invoked Expected: {0}, Invoked {1}".format(count.get_counter(), expected)) + test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter())) return to_return # Name of the wrapped function must match the original or unittest will error out. 
wrapped_function.__name__ = fn.__name__ diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index 0480fe43e8..4980415208 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -17,12 +17,14 @@ except ImportError: import unittest # noqa +import sys from datetime import datetime, timedelta, date, tzinfo from decimal import Decimal as D from uuid import uuid4, uuid1 from cassandra import InvalidRequest from cassandra.cqlengine.columns import TimeUUID +from cassandra.cqlengine.columns import Ascii from cassandra.cqlengine.columns import Text from cassandra.cqlengine.columns import Integer from cassandra.cqlengine.columns import BigInt @@ -337,50 +339,249 @@ def test_default_zero_fields_validate(self): it.validate() -class TestText(BaseCassEngTestCase): +class TestAscii(BaseCassEngTestCase): def test_min_length(self): - # not required defaults to 0 - col = Text() - col.validate('') - col.validate('b') + """ Test arbitrary minimal lengths requirements. """ + Ascii(min_length=0).validate('') + Ascii(min_length=0).validate(None) + Ascii(min_length=0).validate('kevin') + + Ascii(min_length=1).validate('k') + + Ascii(min_length=5).validate('kevin') + Ascii(min_length=5).validate('kevintastic') - # required defaults to 1 with self.assertRaises(ValidationError): - Text(required=True).validate('') + Ascii(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate(None) + + with self.assertRaises(ValidationError): + Ascii(min_length=6).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(min_length=-1) + + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. """ + Ascii(max_length=0).validate('') + Ascii(max_length=0).validate(None) + + Ascii(max_length=1).validate('') + Ascii(max_length=1).validate(None) + Ascii(max_length=1).validate('b') + + Ascii(max_length=5).validate('') + Ascii(max_length=5).validate(None) + Ascii(max_length=5).validate('b') + Ascii(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Ascii(max_length=0).validate('b') + + with self.assertRaises(ValidationError): + Ascii(max_length=5).validate('blaketastic') + + with self.assertRaises(ValueError): + Ascii(max_length=-1) + + def test_length_range(self): + Ascii(min_length=0, max_length=0) + Ascii(min_length=0, max_length=1) + Ascii(min_length=10, max_length=10) + Ascii(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Ascii(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Ascii(min_length=1, max_length=0) + + def test_type_checking(self): + Ascii().validate('string') + Ascii().validate(u'unicode') + Ascii().validate(bytearray('bytearray', encoding='ascii')) + + with self.assertRaises(ValidationError): + Ascii().validate(5) + + with self.assertRaises(ValidationError): + Ascii().validate(True) + + Ascii().validate("!#$%&\'()*+,-./") + + with self.assertRaises(ValidationError): + Ascii().validate('Beyonc' + chr(233)) + + if sys.version_info < (3, 1): + with self.assertRaises(ValidationError): + Ascii().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. 
""" + self.assertEqual(Ascii().validate(''), '') + self.assertEqual(Ascii().validate(None), None) + self.assertEqual(Ascii().validate('yo'), 'yo') + + def test_non_required_validation(self): + """ Tests that validation is ok on none and blank values if required is False. """ + Ascii().validate('') + Ascii().validate(None) + + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Ascii(required=True).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate('') + + with self.assertRaises(ValidationError): + Ascii(required=True).validate(None) + + # With min_length set. + Ascii(required=True, min_length=0).validate('k') + Ascii(required=True, min_length=1).validate('k') + + with self.assertRaises(ValidationError): + Ascii(required=True, min_length=2).validate('k') + + # With max_length set. + Ascii(required=True, max_length=1).validate('k') - #test arbitrary lengths + with self.assertRaises(ValidationError): + Ascii(required=True, max_length=2).validate('kevin') + + with self.assertRaises(ValueError): + Ascii(required=True, max_length=0) + + +class TestText(BaseCassEngTestCase): + + def test_min_length(self): + """ Test arbitrary minimal lengths requirements. """ Text(min_length=0).validate('') + Text(min_length=0).validate(None) + Text(min_length=0).validate('blake') + + Text(min_length=1).validate('b') + Text(min_length=5).validate('blake') Text(min_length=5).validate('blaketastic') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=1).validate(None) + + with self.assertRaises(ValidationError): + Text(min_length=6).validate('') + + with self.assertRaises(ValidationError): + Text(min_length=6).validate(None) + with self.assertRaises(ValidationError): Text(min_length=6).validate('blake') + with self.assertRaises(ValueError): + Text(min_length=-1) + def test_max_length(self): + """ Test arbitrary maximal lengths requirements. """ + Text(max_length=0).validate('') + Text(max_length=0).validate(None) + + Text(max_length=1).validate('') + Text(max_length=1).validate(None) + Text(max_length=1).validate('b') + Text(max_length=5).validate('') + Text(max_length=5).validate(None) + Text(max_length=5).validate('b') Text(max_length=5).validate('blake') + + with self.assertRaises(ValidationError): + Text(max_length=0).validate('b') + with self.assertRaises(ValidationError): Text(max_length=5).validate('blaketastic') + with self.assertRaises(ValueError): + Text(max_length=-1) + + def test_length_range(self): + Text(min_length=0, max_length=0) + Text(min_length=0, max_length=1) + Text(min_length=10, max_length=10) + Text(min_length=10, max_length=11) + + with self.assertRaises(ValueError): + Text(min_length=10, max_length=9) + + with self.assertRaises(ValueError): + Text(min_length=1, max_length=0) + def test_type_checking(self): Text().validate('string') Text().validate(u'unicode') Text().validate(bytearray('bytearray', encoding='ascii')) - with self.assertRaises(ValidationError): - Text(required=True).validate(None) - with self.assertRaises(ValidationError): Text().validate(5) with self.assertRaises(ValidationError): Text().validate(True) + Text().validate("!#$%&\'()*+,-./") + Text().validate('Beyonc' + chr(233)) + if sys.version_info < (3, 1): + Text().validate(u'Beyonc' + unichr(233)) + + def test_unaltering_validation(self): + """ Test the validation step doesn't re-interpret values. 
""" + self.assertEqual(Text().validate(''), '') + self.assertEqual(Text().validate(None), None) + self.assertEqual(Text().validate('yo'), 'yo') + def test_non_required_validation(self): """ Tests that validation is ok on none and blank values if required is False """ Text().validate('') Text().validate(None) + def test_required_validation(self): + """ Tests that validation raise on none and blank values if value required. """ + Text(required=True).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True).validate('') + + with self.assertRaises(ValidationError): + Text(required=True).validate(None) + + # With min_length set. + Text(required=True, min_length=0).validate('b') + Text(required=True, min_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, min_length=2).validate('b') + + # With max_length set. + Text(required=True, max_length=1).validate('b') + + with self.assertRaises(ValidationError): + Text(required=True, max_length=2).validate('blake') + + with self.assertRaises(ValueError): + Text(required=True, max_length=0) + class TestExtraFieldsRaiseException(BaseCassEngTestCase): class TestModel(Model): diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py index 8147e41079..e447056376 100644 --- a/tests/integration/cqlengine/model/test_class_construction.py +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -47,9 +47,30 @@ class TestModel(Model): inst = TestModel() self.assertHasAttr(inst, 'id') self.assertHasAttr(inst, 'text') - self.assertIsNone(inst.id) + self.assertIsNotNone(inst.id) self.assertIsNone(inst.text) + def test_values_on_instantiation(self): + """ + Tests defaults and user-provided values on instantiation. + """ + + class TestPerson(Model): + first_name = columns.Text(primary_key=True, default='kevin') + last_name = columns.Text(default='deldycke') + + # Check that defaults are available at instantiation. + inst1 = TestPerson() + self.assertHasAttr(inst1, 'first_name') + self.assertHasAttr(inst1, 'last_name') + self.assertEqual(inst1.first_name, 'kevin') + self.assertEqual(inst1.last_name, 'deldycke') + + # Check that values on instantiation overrides defaults. + inst2 = TestPerson(first_name='bob', last_name='joe') + self.assertEqual(inst2.first_name, 'bob') + self.assertEqual(inst2.last_name, 'joe') + def test_db_map(self): """ Tests that the db_map is properly defined diff --git a/tests/integration/cqlengine/model/test_model.py b/tests/integration/cqlengine/model/test_model.py index e46698ff75..b31b8d5aee 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -22,7 +22,8 @@ from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace from cassandra.cqlengine import models from cassandra.cqlengine.models import Model, ModelDefinitionException - +from uuid import uuid1 +from tests.integration import pypy class TestModel(unittest.TestCase): """ Tests the non-io functionality of models """ @@ -172,4 +173,37 @@ class IllegalFilterColumnModel(Model): my_primary_key = columns.Integer(primary_key=True) filter = columns.Text() +@pypy +class ModelOverWriteTest(unittest.TestCase): + + def test_model_over_write(self): + """ + Test to ensure overwriting of primary keys in model inheritance is allowed + + This is currently only an issue in PyPy. 
When PYTHON-504 is introduced this should
+        be updated to error out and warn the user
+
+        @since 3.6.0
+        @jira_ticket PYTHON-576
+        @expected_result primary keys can be overwritten via inheritance
+
+        @test_category object_mapper
+        """
+        class TimeModelBase(Model):
+            uuid = columns.TimeUUID(primary_key=True)
+
+        class DerivedTimeModel(TimeModelBase):
+            __table_name__ = 'derived_time'
+            uuid = columns.TimeUUID(primary_key=True, partition_key=True)
+            value = columns.Text(required=False)
+
+        # In case the table already exists in keyspace
+        drop_table(DerivedTimeModel)
+
+        sync_table(DerivedTimeModel)
+        uuid_value = uuid1()
+        uuid_value2 = uuid1()
+        DerivedTimeModel.create(uuid=uuid_value, value="first")
+        DerivedTimeModel.create(uuid=uuid_value2, value="second")
+        DerivedTimeModel.objects.filter(uuid=uuid_value)
diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py
index 3faf62febc..c5fd5e37ca 100644
--- a/tests/integration/cqlengine/model/test_model_io.py
+++ b/tests/integration/cqlengine/model/test_model_io.py
@@ -38,8 +38,6 @@
 from tests.integration.cqlengine import DEFAULT_KEYSPACE
 
-
-
 class TestModel(Model):
 
     id = columns.UUID(primary_key=True, default=lambda: uuid4())
@@ -72,7 +70,7 @@ def tearDownClass(cls):
 
     def test_model_save_and_load(self):
         """
-        Tests that models can be saved and retrieved
+        Tests that models can be saved and retrieved, using the create method.
        """
         tm = TestModel.create(count=8, text='123456789')
         self.assertIsInstance(tm, TestModel)
@@ -83,6 +81,22 @@
         for cname in tm._columns.keys():
             self.assertEqual(getattr(tm, cname), getattr(tm2, cname))
 
+    def test_model_instantiation_save_and_load(self):
+        """
+        Tests that models can be saved and retrieved, this time using the
+        natural model instantiation.
+        """
+        tm = TestModel(count=8, text='123456789')
+        # Tests that values are available on instantiation.
+        self.assertIsNotNone(tm['id'])
+        self.assertEqual(tm.count, 8)
+        self.assertEqual(tm.text, '123456789')
+        tm.save()
+        tm2 = TestModel.objects(id=tm.id).first()
+
+        for cname in tm._columns.keys():
+            self.assertEqual(getattr(tm, cname), getattr(tm2, cname))
+
     def test_model_read_as_dict(self):
         """
         Tests that columns of an instance can be read as a dict.
@@ -468,6 +482,49 @@ def test_previous_value_tracking_on_instantiation(self):
         self.assertTrue(self.instance._values['count'].previous_value is None)
         self.assertTrue(self.instance.count is None)
 
+    def test_previous_value_tracking_on_instantiation_with_default(self):
+
+        class TestDefaultValueTracking(Model):
+            id = columns.Integer(partition_key=True)
+            int1 = columns.Integer(default=123)
+            int2 = columns.Integer(default=456)
+            int3 = columns.Integer(default=lambda: random.randint(0, 1000))
+            int4 = columns.Integer(default=lambda: random.randint(0, 1000))
+            int5 = columns.Integer()
+            int6 = columns.Integer()
+
+        instance = TestDefaultValueTracking(
+            id=1,
+            int1=9999,
+            int3=7777,
+            int5=5555)
+
+        self.assertEqual(instance.id, 1)
+        self.assertEqual(instance.int1, 9999)
+        self.assertEqual(instance.int2, 456)
+        self.assertEqual(instance.int3, 7777)
+        self.assertIsNotNone(instance.int4)
+        self.assertIsInstance(instance.int4, int)
+        self.assertGreaterEqual(instance.int4, 0)
+        self.assertLessEqual(instance.int4, 1000)
+        self.assertEqual(instance.int5, 5555)
+        self.assertTrue(instance.int6 is None)
+
+        # All previous values are unset as the object hasn't been persisted
+        # yet.
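# Illustrative sketch (editorial aside, not part of this patch), restating the
# behaviour exercised above with the same model: defaults are applied at
# construction time, so defaulted columns already count as changed while every
# previous_value stays None until the first save().
#
#     inst = TestDefaultValueTracking(id=1)
#     inst.int1                                    # -> 123, default applied
#     inst._values['int1'].previous_value is None  # -> True, nothing persisted
#     'int1' in inst.get_changed_columns()         # -> True, pending write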
+ self.assertTrue(instance._values['id'].previous_value is None) + self.assertTrue(instance._values['int1'].previous_value is None) + self.assertTrue(instance._values['int2'].previous_value is None) + self.assertTrue(instance._values['int3'].previous_value is None) + self.assertTrue(instance._values['int4'].previous_value is None) + self.assertTrue(instance._values['int5'].previous_value is None) + self.assertTrue(instance._values['int6'].previous_value is None) + + # All explicitely set columns, and those with default values are + # flagged has changed. + self.assertTrue(set(instance.get_changed_columns()) == set([ + 'id', 'int1', 'int2', 'int3', 'int4', 'int5'])) + def test_save_to_none(self): """ Test update of column value of None with save() function. diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py index 242bffe12f..bc39d142cf 100644 --- a/tests/integration/cqlengine/model/test_updates.py +++ b/tests/integration/cqlengine/model/test_updates.py @@ -79,8 +79,8 @@ def test_update_values(self): self.assertEqual(m2.count, m1.count) self.assertEqual(m2.text, m0.text) - def test_noop_model_update(self): - """ tests that calling update on a model with no changes will do nothing. """ + def test_noop_model_direct_update(self): + """ Tests that calling update on a model with no changes will do nothing. """ m0 = TestUpdateModel.create(count=5, text='monkey') with patch.object(self.session, 'execute') as execute: @@ -91,6 +91,38 @@ def test_noop_model_update(self): m0.update(count=5) assert execute.call_count == 0 + with self.assertRaises(ValidationError): + m0.update(partition=m0.partition) + + with self.assertRaises(ValidationError): + m0.update(cluster=m0.cluster) + + def test_noop_model_assignation_update(self): + """ Tests that assigning the same value on a model will do nothing. """ + # Create object and fetch it back to eliminate any hidden variable + # cache effect. + m0 = TestUpdateModel.create(count=5, text='monkey') + m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + + with patch.object(self.session, 'execute') as execute: + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.count = 5 + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.partition = m0.partition + m1.save() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m1.cluster = m0.cluster + m1.save() + assert execute.call_count == 0 + def test_invalid_update_kwarg(self): """ tests that passing in a kwarg to the update method that isn't a column will fail """ m0 = TestUpdateModel.create(count=5, text='monkey') diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 9cddbece17..55129cb985 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -342,7 +342,7 @@ def test_named_table_with_mv(self): # Populate the base table with data prepared_insert = self.session.prepare("""INSERT INTO {0}.scores (user, game, year, month, day, score) VALUES (?, ?, ? ,? 
,?, ?)""".format(ks)) - parameters = {('pcmanus', 'Coup', 2015, 5, 1, 4000), + parameters = (('pcmanus', 'Coup', 2015, 5, 1, 4000), ('jbellis', 'Coup', 2015, 5, 3, 1750), ('yukim', 'Coup', 2015, 5, 3, 2250), ('tjake', 'Coup', 2015, 5, 3, 500), @@ -353,7 +353,7 @@ def test_named_table_with_mv(self): ('jbellis', 'Coup', 2015, 6, 20, 3500), ('jbellis', 'Checkers', 2015, 6, 20, 1200), ('jbellis', 'Chess', 2015, 6, 21, 3500), - ('pcmanus', 'Chess', 2015, 1, 25, 3200)} + ('pcmanus', 'Chess', 2015, 1, 25, 3200)) prepared_insert.consistency_level = ConsistencyLevel.ALL execute_concurrent_with_args(self.session, prepared_insert, parameters) diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index c2a2a74206..055e8f3db2 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -72,7 +72,7 @@ def tearDown(self): super(TestTokenFunction, self).tearDown() drop_table(TokenTestModel) - @execute_count(14) + @execute_count(15) def test_token_function(self): """ Tests that token functions work properly """ assert TokenTestModel.objects().count() == 0 @@ -91,6 +91,10 @@ def test_token_function(self): assert len(seen_keys) == 10 assert all([i in seen_keys for i in range(10)]) + # pk__token equality + r = TokenTestModel.objects(pk__token=functions.Token(last_token)) + self.assertEqual(len(r), 1) + def test_compound_pk_token_function(self): class TestModel(Model): diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index 0776d67943..ea303373b8 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -268,6 +268,7 @@ def test_defining_defer_fields(self): @since 3.5 @jira_ticket PYTHON-560 + @jira_ticket PYTHON-599 @expected_result deferred fields should not be returned @test_category object_mapper @@ -300,6 +301,10 @@ def test_defining_defer_fields(self): q = TestModel.objects.filter(test_id=0) self.assertEqual(q._select_fields(), ['attempt_id', 'description', 'expected_result', 'test_result']) + # when all fields are defered, it fallbacks select the partition keys + q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) + self.assertEqual(q._select_fields(), ['test_id']) + class BaseQuerySetUsage(BaseCassEngTestCase): @@ -847,16 +852,12 @@ def test_tzaware_datetime_support(self): def test_success_case(self): """ Test that the min and max time uuid functions work as expected """ pk = uuid4() - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='1') - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='2') - time.sleep(0.2) - midpoint = datetime.utcnow() - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='3') - time.sleep(0.2) - TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') - time.sleep(0.2) + startpoint = datetime.utcnow() + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=1)), data='1') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=2)), data='2') + midpoint = startpoint + timedelta(seconds=3) + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=4)), data='3') + TimeUUIDQueryModel.create(partition=pk, time=uuid_from_time(startpoint + timedelta(seconds=5)), data='4') # test 
kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) @@ -894,7 +895,6 @@ def test_success_case(self): class TestInOperator(BaseQuerySetUsage): - @execute_count(1) def test_kwarg_success_case(self): """ Tests the in operator works with the kwarg query method """ @@ -907,6 +907,51 @@ def test_query_expression_success_case(self): q = TestModel.filter(TestModel.test_id.in_([0, 1])) assert q.count() == 8 + @execute_count(5) + def test_bool(self): + """ + Adding coverage to cqlengine for bool types. + + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result bool results should be filtered appropriately + + @test_category object_mapper + """ + class bool_model(Model): + k = columns.Integer(primary_key=True) + b = columns.Boolean(primary_key=True) + v = columns.Integer(default=3) + sync_table(bool_model) + + bool_model.create(k=0, b=True) + bool_model.create(k=0, b=False) + self.assertEqual(len(bool_model.objects.all()), 2) + self.assertEqual(len(bool_model.objects.filter(k=0, b=True)), 1) + self.assertEqual(len(bool_model.objects.filter(k=0, b=False)), 1) + + @execute_count(3) + def test_bool_filter(self): + """ + Test to ensure that we don't translate boolean objects to String unnecessarily in filter clauses + + @since 3.6 + @jira_ticket PYTHON-596 + @expected_result We should not receive a server error + + @test_category object_mapper + """ + class bool_model2(Model): + k = columns.Boolean(primary_key=True) + b = columns.Integer(primary_key=True) + v = columns.Text() + drop_table(bool_model2) + sync_table(bool_model2) + + bool_model2.create(k=True, b=1, v='a') + bool_model2.create(k=False, b=1, v='b') + self.assertEqual(len(list(bool_model2.objects(k__in=(True, False)))), 2) + @greaterthancass20 class TestContainsOperator(BaseQuerySetUsage): @@ -1353,5 +1398,3 @@ def test_defaultFetchSize(self): smiths = list(People2.filter(last_name="Smith")) self.assertEqual(len(smiths), 5) self.assertTrue(smiths[0].last_name is not None) - - diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py new file mode 100644 index 0000000000..b3941319e9 --- /dev/null +++ b/tests/integration/cqlengine/test_context_query.py @@ -0,0 +1,127 @@ +# Copyright 2013-2016 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
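# Usage sketch (editorial aside, not part of this patch): as the tests below
# verify, ContextQuery hands back a copy of the model bound to the requested
# keyspace, leaving the original model untouched.
#
#     from cassandra.cqlengine.query import ContextQuery
#
#     with ContextQuery(TestModel, keyspace='ks2') as tm:
#         tm.objects.create(partition=1, cluster=1)  # routed to ks2
#     assert TestModel._get_keyspace() == 'ks1'      # original unchanged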
+ +from cassandra.cqlengine import columns +from cassandra.cqlengine.management import drop_keyspace, sync_table, create_keyspace_simple +from cassandra.cqlengine.models import Model +from cassandra.cqlengine.query import ContextQuery +from tests.integration.cqlengine.base import BaseCassEngTestCase + + +class TestModel(Model): + + __keyspace__ = 'ks1' + + partition = columns.Integer(primary_key=True) + cluster = columns.Integer(primary_key=True) + count = columns.Integer() + text = columns.Text() + + +class ContextQueryTests(BaseCassEngTestCase): + + KEYSPACES = ('ks1', 'ks2', 'ks3', 'ks4') + + @classmethod + def setUpClass(cls): + super(ContextQueryTests, cls).setUpClass() + for ks in cls.KEYSPACES: + create_keyspace_simple(ks, 1) + sync_table(TestModel, keyspaces=cls.KEYSPACES) + + @classmethod + def tearDownClass(cls): + super(ContextQueryTests, cls).tearDownClass() + for ks in cls.KEYSPACES: + drop_keyspace(ks) + + def setUp(self): + super(ContextQueryTests, self).setUp() + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + for obj in tm.all(): + obj.delete() + + def test_context_manager(self): + """ + Validates that when a context query is constructed that the + keyspace of the returned model is toggled appropriately + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for ks in self.KEYSPACES: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(tm.__keyspace__, ks) + + self.assertEqual(TestModel._get_keyspace(), 'ks1') + + def test_default_keyspace(self): + """ + Tests the use of context queries with the default model keyspsace + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result default keyspace should be used + + @test_category query + """ + # model keyspace write/read + for i in range(5): + TestModel.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel) as tm: + self.assertEqual(5, len(tm.objects.all())) + + with ContextQuery(TestModel, keyspace='ks1') as tm: + self.assertEqual(5, len(tm.objects.all())) + + for ks in self.KEYSPACES[1:]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + def test_context_keyspace(self): + """ + Tests the use of context queries with non default keyspaces + + @since 3.6 + @jira_ticket PYTHON-598 + @expected_result queries should be routed to appropriate keyspaces + + @test_category query + """ + for i in range(5): + with ContextQuery(TestModel, keyspace='ks4') as tm: + tm.objects.create(partition=i, cluster=i) + + with ContextQuery(TestModel, keyspace='ks4') as tm: + self.assertEqual(5, len(tm.objects.all())) + + self.assertEqual(0, len(TestModel.objects.all())) + + for ks in self.KEYSPACES[:2]: + with ContextQuery(TestModel, keyspace=ks) as tm: + self.assertEqual(0, len(tm.objects.all())) + + # simple data update + with ContextQuery(TestModel, keyspace='ks4') as tm: + obj = tm.objects.get(partition=1) + obj.update(count=42) + + self.assertEqual(42, tm.objects.get(partition=1).count) + diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py index d273df9cc0..8395154c34 100644 --- a/tests/integration/cqlengine/test_lwt_conditional.py +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -234,3 +234,18 @@ def test_update_to_none(self): self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) 
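        # Aside (sketch, not part of this patch): when an iff() condition does
        # not match, cqlengine raises LWTException instead of silently skipping
        # the write:
        #
        #     from cassandra.cqlengine.query import LWTException
        #     try:
        #         TestConditionalModel.objects(id=t.id).iff(count=9999).update(text='x')
        #     except LWTException:
        #         pass  # condition failed; the row is left untouched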
TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + + def test_column_delete_after_update(self): + # DML path + t = TestConditionalModel.create(text='something', count=5) + t.iff(count=5).update(text=None, count=6) + + self.assertIsNone(t.text) + self.assertEqual(t.count, 6) + + # QuerySet path + t = TestConditionalModel.create(text='something', count=5) + TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) + + self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6) diff --git a/tests/integration/cqlengine/test_ttl.py b/tests/integration/cqlengine/test_ttl.py index ba2c1e0935..3e16292781 100644 --- a/tests/integration/cqlengine/test_ttl.py +++ b/tests/integration/cqlengine/test_ttl.py @@ -18,6 +18,7 @@ except ImportError: import unittest # noqa +from cassandra import InvalidRequest from cassandra.cqlengine.management import sync_table, drop_table from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine.models import Model @@ -158,6 +159,16 @@ def test_ttl_included_with_blind_update(self): @unittest.skipIf(CASSANDRA_VERSION < '2.0', "default_time_to_Live was introduce in C* 2.0, currently running {0}".format(CASSANDRA_VERSION)) class TTLDefaultTest(BaseDefaultTTLTest): + def get_default_ttl(self, table_name): + session = get_session() + try: + default_ttl = session.execute("SELECT default_time_to_live FROM system_schema.tables " + "WHERE keyspace_name = 'cqlengine_test' AND table_name = '{0}'".format(table_name)) + except InvalidRequest: + default_ttl = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies " + "WHERE keyspace_name = 'cqlengine_test' AND columnfamily_name = '{0}'".format(table_name)) + return default_ttl[0]['default_time_to_live'] + def test_default_ttl_not_set(self): session = get_session() @@ -166,6 +177,9 @@ def test_default_ttl_not_set(self): self.assertIsNone(o._ttl) + default_ttl = self.get_default_ttl('test_ttlmodel') + self.assertEqual(default_ttl, 0) + with mock.patch.object(session, 'execute') as m: TestTTLModel.objects(id=tid).update(text="aligators") @@ -174,23 +188,44 @@ def test_default_ttl_not_set(self): def test_default_ttl_set(self): session = get_session() + o = TestDefaultTTLModel.create(text="some text on ttl") tid = o.id - self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__) + # Should not be set, it's handled by Cassandra + self.assertIsNone(o._ttl) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) with mock.patch.object(session, 'execute') as m: - TestDefaultTTLModel.objects(id=tid).update(text="aligators expired") + TestTTLModel.objects(id=tid).update(text="aligators expired") + # Should not be set either query = m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + self.assertNotIn("USING TTL", query) + + def test_default_ttl_modify(self): + session = get_session() + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 20) + + TestDefaultTTLModel.__options__ = {'default_time_to_live': 10} + sync_table(TestDefaultTTLModel) + + default_ttl = self.get_default_ttl('test_default_ttlmodel') + self.assertEqual(default_ttl, 10) + + # Restore default TTL + TestDefaultTTLModel.__options__ = {'default_time_to_live': 20} + sync_table(TestDefaultTTLModel) def test_override_default_ttl(self): 
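        # Aside (sketch, not part of this patch): a TTL set on the instance
        # takes precedence over the table's default_time_to_live, roughly:
        #
        #     o = TestDefaultTTLModel.create(text="x")  # expires per table default
        #     o.ttl(3600)                               # per-write override
        #     o.save()                                  # generated CQL gains "USING TTL 3600"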
         session = get_session()
 
         o = TestDefaultTTLModel.create(text="some text on ttl")
         tid = o.id
 
-        self.assertEqual(o._ttl, TestDefaultTTLModel.__default_ttl__)
         o.ttl(3600)
         self.assertEqual(o._ttl, 3600)
diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py
index 9ae84f38d6..99875afe48 100644
--- a/tests/integration/long/test_ssl.py
+++ b/tests/integration/long/test_ssl.py
@@ -17,7 +17,7 @@
 except ImportError:
     import unittest
 
-import os, sys, traceback, logging, ssl
+import os, sys, traceback, logging, ssl, time
 from cassandra.cluster import Cluster, NoHostAvailable
 from cassandra import ConsistencyLevel
 from cassandra.query import SimpleStatement
@@ -86,7 +86,7 @@ def validate_ssl_options(ssl_options):
             raise RuntimeError("Failed to connect to SSL cluster after 5 attempts")
         try:
             cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options)
-            session = cluster.connect()
+            session = cluster.connect(wait_for_all_pools=True)
             break
         except Exception:
             ex_type, ex, tb = sys.exc_info()
@@ -132,11 +132,47 @@ def test_can_connect_with_ssl_ca(self):
 
         @test_category connection:ssl
         """
+        # find absolute path to client CA_CERTS
+        abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS)
+        ssl_options = {'ca_certs': abs_path_ca_cert_path, 'ssl_version': ssl.PROTOCOL_TLSv1}
+        validate_ssl_options(ssl_options=ssl_options)
+
+    def test_can_connect_with_ssl_long_running(self):
+        """
+        Test to validate that long-running SSL connections continue to function past their timeout window
+
+        @since 3.6.0
+        @jira_ticket PYTHON-600
+        @expected_result The client can connect via SSL and perform some basic operations over a period of longer than a minute
+
+        @test_category connection:ssl
+        """
+
         # find absolute path to client CA_CERTS
         abs_path_ca_cert_path = os.path.abspath(CLIENT_CA_CERTS)
         ssl_options = {'ca_certs': abs_path_ca_cert_path,
                        'ssl_version': ssl.PROTOCOL_TLSv1}
-        validate_ssl_options(ssl_options=ssl_options)
+        tries = 0
+        while True:
+            if tries > 5:
+                raise RuntimeError("Failed to connect to SSL cluster after 5 attempts")
+            try:
+                cluster = Cluster(protocol_version=PROTOCOL_VERSION, ssl_options=ssl_options)
+                session = cluster.connect(wait_for_all_pools=True)
+                break
+            except Exception:
+                ex_type, ex, tb = sys.exc_info()
+                log.warn("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb)))
+                del tb
+                tries += 1
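# Aside (sketch, not part of this patch): ssl_options is passed through as
# keyword arguments to ssl.wrap_socket(), so any of its kwargs apply; a
# stricter variant of the options above could also demand a verified peer
# certificate:
#
#     ssl_options = {'ca_certs': abs_path_ca_cert_path,
#                    'ssl_version': ssl.PROTOCOL_TLSv1,
#                    'cert_reqs': ssl.CERT_REQUIRED}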
+
+        # attempt a few simple commands
+        for i in range(8):
+            rs = session.execute("SELECT * FROM system.local")
+            time.sleep(10)
+
+        cluster.shutdown()
 
     def test_can_connect_with_ssl_ca_host_match(self):
         """
diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py
index 62244b93f3..4c8339a2cf 100644
--- a/tests/integration/standard/test_cluster.py
+++ b/tests/integration/standard/test_cluster.py
@@ -33,7 +33,8 @@
 from cassandra.protocol import MAX_SUPPORTED_VERSION
 from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
 
-from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler
+from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, get_node, CASSANDRA_VERSION, execute_until_pass, execute_with_long_wait_retry, get_node,\
+    MockLoggingHandler, get_unsupported_lower_protocol, get_unsupported_upper_protocol
 
 from tests.integration.util import assert_quiescent_pool_state
 
@@ -41,8 +42,40 @@ def setup_module():
     use_singledc()
 
 
+class IgnoredHostPolicy(RoundRobinPolicy):
+
+    def __init__(self, ignored_hosts):
+        self.ignored_hosts = ignored_hosts
+        RoundRobinPolicy.__init__(self)
+
+    def distance(self, host):
+        if str(host) in self.ignored_hosts:
+            return HostDistance.IGNORED
+        else:
+            return HostDistance.LOCAL
+
+
 class ClusterTests(unittest.TestCase):
 
+    def test_ignored_host_up(self):
+        """
+        Test to ensure that is_up is not set by default on ignored hosts
+
+        @since 3.6
+        @jira_ticket PYTHON-551
+        @expected_result ignored hosts should have None set for is_up
+
+        @test_category connection
+        """
+        ignored_host_policy = IgnoredHostPolicy(["127.0.0.2", "127.0.0.3"])
+        cluster = Cluster(protocol_version=PROTOCOL_VERSION, load_balancing_policy=ignored_host_policy)
+        session = cluster.connect()
+        for host in cluster.metadata.all_hosts():
+            if str(host) == "127.0.0.1":
+                self.assertTrue(host.is_up)
+            else:
+                self.assertIsNone(host.is_up)
+
     def test_host_resolution(self):
         """
         Test to ensure A records are resolved appropriately.
@@ -67,11 +100,11 @@ def test_host_duplication(self):
 
         @test_category connection
         """
         cluster = Cluster(contact_points=["localhost", "127.0.0.1", "localhost", "localhost", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1)
-        cluster.connect()
+        cluster.connect(wait_for_all_pools=True)
         self.assertEqual(len(cluster.metadata.all_hosts()), 3)
         cluster.shutdown()
         cluster = Cluster(contact_points=["127.0.0.1", "localhost"], protocol_version=PROTOCOL_VERSION, connect_timeout=1)
-        cluster.connect()
+        cluster.connect(wait_for_all_pools=True)
         self.assertEqual(len(cluster.metadata.all_hosts()), 3)
         cluster.shutdown()
 
@@ -175,6 +208,42 @@ def test_protocol_negotiation(self):
 
         cluster.shutdown()
 
+    def test_invalid_protocol_negotiation(self):
+        """
+        Test for protocol negotiation when explicit versions are set
+
+        If an explicit protocol version that is not compatible with the server version is set,
+        an exception should be thrown. It should not attempt to negotiate.
+
+        For reference, supported protocol versions map to server versions as follows:
+
+        1.2 -> 1
+        2.0 -> 2, 1
+        2.1 -> 3, 2, 1
+        2.2 -> 4, 3, 2, 1
+        3.X -> 4, 3
+
+        @since 3.6.0
+        @jira_ticket PYTHON-537
+        @expected_result downgrading should not be allowed when explicit protocol versions are set.
+ + @test_category connection + """ + + upper_bound = get_unsupported_upper_protocol() + if upper_bound is not None: + cluster = Cluster(protocol_version=upper_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + + lower_bound = get_unsupported_lower_protocol() + if lower_bound is not None: + cluster = Cluster(protocol_version=lower_bound) + with self.assertRaises(NoHostAvailable): + cluster.connect() + cluster.shutdown() + def test_connect_on_keyspace(self): """ Ensure clusters that connect on a keyspace, do @@ -516,14 +585,14 @@ def test_idle_heartbeat(self): cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=interval) if PROTOCOL_VERSION < 3: cluster.set_core_connections_per_host(HostDistance.LOCAL, 1) - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) # This test relies on impl details of connection req id management to see if heartbeats # are being sent. May need update if impl is changed connection_request_ids = {} for h in cluster.get_connection_holders(): for c in h.get_connections(): - # make sure none are idle (should have startup messages) + # make sure none are idle (should have startup messages self.assertFalse(c.is_idle) with c.lock: connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids @@ -558,7 +627,7 @@ def test_idle_heartbeat(self): self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc # include additional sessions - session2 = cluster.connect() + session2 = cluster.connect(wait_for_all_pools=True) holders = cluster.get_connection_holders() self.assertIn(cluster.control_connection, holders) @@ -631,7 +700,7 @@ def test_profile_load_balancing(self): query = "select release_version from system.local" node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) with Cluster(execution_profiles={'node1': node1}) as cluster: - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) @@ -688,7 +757,7 @@ def test_profile_lb_swap(self): rr2 = ExecutionProfile(load_balancing_policy=RoundRobinPolicy()) exec_profiles = {'rr1': rr1, 'rr2': rr2} with Cluster(execution_profiles=exec_profiles) as cluster: - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) # default is DCA RR for all hosts expected_hosts = set(cluster.metadata.all_hosts()) @@ -780,7 +849,7 @@ def test_profile_pool_management(self): node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) node2 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.2'])) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1, 'node2': node2}) as cluster: - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() # there are more hosts, but we connected to the ones in the lbp aggregate self.assertGreater(len(cluster.metadata.all_hosts()), 2) @@ -805,7 +874,7 @@ def test_add_profile_timeout(self): node1 = ExecutionProfile(load_balancing_policy=WhiteListRoundRobinPolicy(['127.0.0.1'])) with Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: - session = cluster.connect() + session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() self.assertGreater(len(cluster.metadata.all_hosts()), 2) self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) diff 
--git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 2d07b92038..69566c80ad 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -99,7 +99,7 @@ class HeartbeatTest(unittest.TestCase): def setUp(self): self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, idle_heartbeat_interval=1) - self.session = self.cluster.connect() + self.session = self.cluster.connect(wait_for_all_pools=True) def tearDown(self): self.cluster.shutdown() diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 63a8380902..c6818f7f4b 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -126,7 +126,7 @@ class CustomResultMessageRaw(ResultMessage): type_codes = my_type_codes @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): + def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] @@ -155,7 +155,7 @@ class CustomResultMessageTracked(ResultMessage): checked_rev_row_set = set() @classmethod - def recv_results_rows(cls, f, protocol_version, user_type_map): + def recv_results_rows(cls, f, protocol_version, user_type_map, result_metadata): paging_state, column_metadata = cls.recv_results_metadata(f, user_type_map) rowcount = read_int(f) rows = [cls.recv_row(f, len(column_metadata)) for _ in range(rowcount)] diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 3560709faa..7dc3db300e 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ b/tests/integration/standard/test_cython_protocol_handlers.py @@ -7,10 +7,12 @@ except ImportError: import unittest -from cassandra import DriverException, Timeout, AlreadyExists +from itertools import count + from cassandra.query import tuple_factory from cassandra.cluster import Cluster, NoHostAvailable -from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler, ConfigurationException +from cassandra.concurrent import execute_concurrent_with_args +from cassandra.protocol import ProtocolHandler, LazyProtocolHandler, NumpyProtocolHandler from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY from tests.integration import use_singledc, PROTOCOL_VERSION, notprotocolv1, drop_keyspace_shutdown_cluster, VERIFY_CYTHON, BasicSharedKeyspaceUnitTestCase, execute_with_retry_tolerant, greaterthancass21 from tests.integration.datatype_utils import update_datatypes @@ -207,66 +209,49 @@ def verify_iterator_data(assertEqual, results): class NumpyNullTest(BasicSharedKeyspaceUnitTestCase): - # A dictionary containing table key to type. 
-    # Boolean dictates whether or not the type can be deserialized with null value
-    NUMPY_TYPES = {"v1": ('bigint', False),
-                   "v2": ('double', False),
-                   "v3": ('float', False),
-                   "v4": ('int', False),
-                   "v5": ('smallint', False),
-                   "v6": ("ascii", True),
-                   "v7": ("blob", True),
-                   "v8": ("boolean", True),
-                   "v9": ("decimal", True),
-                   "v10": ("inet", True),
-                   "v11": ("text", True),
-                   "v12": ("timestamp", True),
-                   "v13": ("timeuuid", True),
-                   "v14": ("uuid", True),
-                   "v15": ("varchar", True),
-                   "v16": ("varint", True),
-                   }
-
-    def setUp(self):
-        self.session.client_protocol_handler = NumpyProtocolHandler
-        self.session.row_factory = tuple_factory
-
     @numpytest
     @greaterthancass21
     def test_null_types(self):
         """
         Test to validate that the numpy protocol handler can deal with null values.
         @since 3.3.0
+         - updated 3.6.0: numeric types now use masked arrays
         @jira_ticket PYTHON-550
         @expected_result Numpy can handle non-mapped types' null values.
 
         @test_category data_types:serialization
         """
-
-        self.create_table_of_types()
-        self.session.execute("INSERT INTO {0}.{1} (k) VALUES (1)".format(self.keyspace_name, self.function_table_name))
-        self.validate_types()
-
-    def create_table_of_types(self):
-        """
-        Builds a table containing all the numpy types
-        """
-        base_ddl = '''CREATE TABLE {0}.{1} (k int PRIMARY KEY'''.format(self.keyspace_name, self.function_table_name, type)
-        for key, value in NumpyNullTest.NUMPY_TYPES.items():
-            base_ddl = base_ddl+", {0} {1}".format(key, value[0])
-        base_ddl = base_ddl+")"
-        execute_with_retry_tolerant(self.session, base_ddl, (DriverException, NoHostAvailable, Timeout), (ConfigurationException, AlreadyExists))
-
-    def validate_types(self):
-        """
-        Selects each type from the table and expects either an exception or None depending on type
-        """
-        for key, value in NumpyNullTest.NUMPY_TYPES.items():
-            select = "SELECT {0} from {1}.{2}".format(key, self.keyspace_name, self.function_table_name)
-            if value[1]:
-                rs = execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ())
-                self.assertEqual(rs[0].get('v1'), None)
+        s = self.session
+        s.row_factory = tuple_factory
+        s.client_protocol_handler = NumpyProtocolHandler
+
+        table = "%s.%s" % (self.keyspace_name, self.function_table_name)
+        create_table_with_all_types(table, s, 10)
+
+        begin_unset = max(s.execute('select primkey from %s' % (table,))[0]['primkey']) + 1
+        keys_null = range(begin_unset, begin_unset + 10)
+
+        # scatter some empty rows in here
+        insert = "insert into %s (primkey) values (%%s)" % (table,)
+        execute_concurrent_with_args(s, insert, ((k,) for k in keys_null))
+
+        result = s.execute("select * from %s" % (table,))[0]
+
+        from numpy.ma import masked, MaskedArray
+        result_keys = result.pop('primkey')
+        mapped_index = [v[1] for v in sorted(zip(result_keys, count()))]
+
+        had_masked = had_none = False
+        for col_array in result.values():
+            # these have to be different branches (as opposed to comparing against an 'unset value')
+            # because None and `masked` have different identity and equals semantics
+            if isinstance(col_array, MaskedArray):
+                had_masked = True
+                [self.assertIsNot(col_array[i], masked) for i in mapped_index[:begin_unset]]
+                [self.assertIs(col_array[i], masked) for i in mapped_index[begin_unset:]]
             else:
-                with self.assertRaises(ValueError):
-                    execute_with_retry_tolerant(self.session, select, (NoHostAvailable), ())
-
+                had_none = True
+                [self.assertIsNotNone(col_array[i]) for i in mapped_index[:begin_unset]]
+                [self.assertIsNone(col_array[i]) for i in mapped_index[begin_unset:]]
+
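# Aside (sketch, not part of this patch): NumPy's fixed-width numeric dtypes
# have no notion of NULL, so the handler returns numpy.ma masked arrays for
# numeric columns and plain object arrays holding None for everything else:
#
#     import numpy.ma as ma
#     a = ma.array([1, 2, 3], mask=[False, True, False])
#     a[1] is ma.masked        # True -> the "null" cell
#     a.compressed().tolist()  # [1, 3] -> only the present values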
+        self.assertTrue(had_masked)
+        self.assertTrue(had_none)
diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py
index c317e50c3e..598dd83971 100644
--- a/tests/integration/standard/test_metadata.py
+++ b/tests/integration/standard/test_metadata.py
@@ -79,6 +79,28 @@ def test_host_release_version(self):
         self.assertTrue(host.release_version.startswith(CASSANDRA_VERSION))
 
 
+class MetaDataRemovalTest(unittest.TestCase):
+
+    def setUp(self):
+        self.cluster = Cluster(protocol_version=PROTOCOL_VERSION, contact_points=['127.0.0.1', '127.0.0.2', '127.0.0.3', '126.0.0.186'])
+        self.cluster.connect()
+
+    def tearDown(self):
+        self.cluster.shutdown()
+
+    def test_bad_contact_point(self):
+        """
+        Checks to ensure that hosts that are not resolvable are excluded from the contact point list.
+
+        @since 3.6
+        @jira_ticket PYTHON-549
+        @expected_result Invalid hosts on the contact list should be excluded
+
+        @test_category metadata
+        """
+        self.assertEqual(len(self.cluster.metadata.all_hosts()), 3)
+
+
 class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase):
 
     def test_schema_metadata_disable(self):
@@ -1133,14 +1155,12 @@ def test_legacy_tables(self):
 
 CREATE TABLE legacy.composite_comp_with_col (
     key blob,
-    b blob,
-    s text,
-    t timeuuid,
+    column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(b=>org.apache.cassandra.db.marshal.BytesType, s=>org.apache.cassandra.db.marshal.UTF8Type, t=>org.apache.cassandra.db.marshal.TimeUUIDType)',
     "b@6869746d65776974686d75736963" blob,
     "b@6d616d6d616a616d6d61" blob,
-    PRIMARY KEY (key, b, s, t)
+    PRIMARY KEY (key, column1)
 ) WITH COMPACT STORAGE
-    AND CLUSTERING ORDER BY (b ASC, s ASC, t ASC)
+    AND CLUSTERING ORDER BY (column1 ASC)
     AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
     AND comment = 'Stores file meta data'
     AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
@@ -1253,20 +1273,13 @@ def test_legacy_tables(self):
     AND read_repair_chance = 0.0
     AND speculative_retry = 'NONE';
 
-/*
-Warning: Table legacy.composite_comp_no_col omitted because it has constructs not compatible with CQL (was created via legacy API).
-
-Approximate structure, for reference:
-(this should not be used to reproduce this schema)
-
 CREATE TABLE legacy.composite_comp_no_col (
     key blob,
-    column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(org.apache.cassandra.db.marshal.BytesType, org.apache.cassandra.db.marshal.UTF8Type, org.apache.cassandra.db.marshal.TimeUUIDType)',
-    column2 timeuuid,
+    column1 'org.apache.cassandra.db.marshal.DynamicCompositeType(b=>org.apache.cassandra.db.marshal.BytesType, s=>org.apache.cassandra.db.marshal.UTF8Type, t=>org.apache.cassandra.db.marshal.TimeUUIDType)',
     value blob,
-    PRIMARY KEY (key, column1, column1, column2)
+    PRIMARY KEY (key, column1)
 ) WITH COMPACT STORAGE
-    AND CLUSTERING ORDER BY (column1 ASC, column1 ASC, column2 ASC)
+    AND CLUSTERING ORDER BY (column1 ASC)
     AND caching = '{"keys":"ALL", "rows_per_partition":"NONE"}'
     AND comment = 'Stores file meta data'
     AND compaction = {'min_threshold': '4', 'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32'}
@@ -1278,8 +1291,7 @@ def test_legacy_tables(self):
     AND memtable_flush_period_in_ms = 0
     AND min_index_interval = 128
     AND read_repair_chance = 0.0
-    AND speculative_retry = 'NONE';
-*/"""
+    AND speculative_retry = 'NONE';"""
 
     ccm = get_cluster()
     ccm.run_cli(cli_script)
 
@@ -2035,7 +2047,31 @@ def test_bad_user_aggregate(self):
         self.assertIn("/*\nWarning:", m.export_as_string())
 
 
-class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase):
+class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase):
+
+    def test_dct_alias(self):
+        """
+        Tests to make sure DCTs have correct string formatting
+
+        Constructs a DCT and checks the generated format to ensure it matches what is expected
+
+        @since 3.6.0
+        @jira_ticket PYTHON-579
+        @expected_result DCT subtypes should always have fully qualified names
+
+        @test_category metadata
+        """
+        self.session.execute("CREATE TABLE {0}.{1} ("
+                             "k int PRIMARY KEY,"
+                             "c1 'DynamicCompositeType(s => UTF8Type, i => Int32Type)',"
+                             "c2 Text)".format(self.ks_name, self.function_table_name))
+        dct_table = self.cluster.metadata.keyspaces.get(self.ks_name).tables.get(self.function_table_name)
+
+        # Format can vary slightly between versions; strip out whitespace for consistency's sake
+        self.assertTrue("c1'org.apache.cassandra.db.marshal.DynamicCompositeType(s=>org.apache.cassandra.db.marshal.UTF8Type,i=>org.apache.cassandra.db.marshal.Int32Type)'" in dct_table.as_cql_query().replace(" ", ""))
+
+
+class MaterializedViewMetadataTestSimple(BasicSharedKeyspaceUnitTestCase):
 
     def setUp(self):
         if CASS_SERVER_VERSION < (3, 0):
@@ -2183,37 +2219,37 @@ def test_create_view_metadata(self):
         self.assertIsNotNone(score_table.columns['score'])
 
         # Validate basic mv information
-        self.assertEquals(mv.keyspace_name, self.keyspace_name)
-        self.assertEquals(mv.name, "monthlyhigh")
-        self.assertEquals(mv.base_table_name, "scores")
+        self.assertEqual(mv.keyspace_name, self.keyspace_name)
+        self.assertEqual(mv.name, "monthlyhigh")
+        self.assertEqual(mv.base_table_name, "scores")
         self.assertFalse(mv.include_all_columns)
 
         # Validate that all columns are present and correct
         mv_columns = list(mv.columns.values())
-        self.assertEquals(len(mv_columns), 6)
+        self.assertEqual(len(mv_columns), 6)
 
         game_column = mv_columns[0]
         self.assertIsNotNone(game_column)
-        self.assertEquals(game_column.name, 'game')
-        self.assertEquals(game_column, mv.partition_key[0])
+        self.assertEqual(game_column.name, 'game')
+        self.assertEqual(game_column, mv.partition_key[0])
 
         year_column = 
mv_columns[1] self.assertIsNotNone(year_column) - self.assertEquals(year_column.name, 'year') - self.assertEquals(year_column, mv.partition_key[1]) + self.assertEqual(year_column.name, 'year') + self.assertEqual(year_column, mv.partition_key[1]) month_column = mv_columns[2] self.assertIsNotNone(month_column) - self.assertEquals(month_column.name, 'month') - self.assertEquals(month_column, mv.partition_key[2]) + self.assertEqual(month_column.name, 'month') + self.assertEqual(month_column, mv.partition_key[2]) def compare_columns(a, b, name): - self.assertEquals(a.name, name) - self.assertEquals(a.name, b.name) - self.assertEquals(a.table, b.table) - self.assertEquals(a.cql_type, b.cql_type) - self.assertEquals(a.is_static, b.is_static) - self.assertEquals(a.is_reversed, b.is_reversed) + self.assertEqual(a.name, name) + self.assertEqual(a.name, b.name) + self.assertEqual(a.table, b.table) + self.assertEqual(a.cql_type, b.cql_type) + self.assertEqual(a.is_static, b.is_static) + self.assertEqual(a.is_reversed, b.is_reversed) score_column = mv_columns[3] compare_columns(score_column, mv.clustering_key[0], 'score') @@ -2290,7 +2326,7 @@ def test_base_table_column_addition_mv(self): self.assertIn("fouls", mv_alltime.columns) mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - self.assertEquals(mv_alltime_fouls_comumn.cql_type, 'int') + self.assertEqual(mv_alltime_fouls_comumn.cql_type, 'int') def test_base_table_type_alter_mv(self): """ @@ -2331,7 +2367,7 @@ def test_base_table_type_alter_mv(self): self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - self.assertEquals(score_column.cql_type, 'blob') + self.assertEqual(score_column.cql_type, 'blob') # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): @@ -2340,7 +2376,7 @@ def test_base_table_type_alter_mv(self): break time.sleep(.2) - self.assertEquals(score_mv_column.cql_type, 'blob') + self.assertEqual(score_mv_column.cql_type, 'blob') def test_metadata_with_quoted_identifiers(self): """ @@ -2393,31 +2429,31 @@ def test_metadata_with_quoted_identifiers(self): self.assertIsNotNone(t1_table.columns['the Value']) # Validate basic mv information - self.assertEquals(mv.keyspace_name, self.keyspace_name) - self.assertEquals(mv.name, "mv1") - self.assertEquals(mv.base_table_name, "t1") + self.assertEqual(mv.keyspace_name, self.keyspace_name) + self.assertEqual(mv.name, "mv1") + self.assertEqual(mv.base_table_name, "t1") self.assertFalse(mv.include_all_columns) # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEquals(len(mv_columns), 3) + self.assertEqual(len(mv_columns), 3) theKey_column = mv_columns[0] self.assertIsNotNone(theKey_column) - self.assertEquals(theKey_column.name, 'theKey') - self.assertEquals(theKey_column, mv.partition_key[0]) + self.assertEqual(theKey_column.name, 'theKey') + self.assertEqual(theKey_column, mv.partition_key[0]) cluster_column = mv_columns[1] self.assertIsNotNone(cluster_column) - self.assertEquals(cluster_column.name, 'the;Clustering') - self.assertEquals(cluster_column.name, mv.clustering_key[0].name) - self.assertEquals(cluster_column.table, mv.clustering_key[0].table) - self.assertEquals(cluster_column.is_static, mv.clustering_key[0].is_static) - self.assertEquals(cluster_column.is_reversed, 
mv.clustering_key[0].is_reversed) + self.assertEqual(cluster_column.name, 'the;Clustering') + self.assertEqual(cluster_column.name, mv.clustering_key[0].name) + self.assertEqual(cluster_column.table, mv.clustering_key[0].table) + self.assertEqual(cluster_column.is_static, mv.clustering_key[0].is_static) + self.assertEqual(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) value_column = mv_columns[2] self.assertIsNotNone(value_column) - self.assertEquals(value_column.name, 'the Value') + self.assertEqual(value_column.name, 'the Value') @dseonly diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 13758b65ad..18f35c15f1 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -14,6 +14,8 @@ import time +from cassandra.policies import WhiteListRoundRobinPolicy, FallthroughRetryPolicy + try: import unittest2 as unittest except ImportError: @@ -24,7 +26,8 @@ from cassandra.cluster import Cluster, NoHostAvailable from tests.integration import get_cluster, get_node, use_singledc, PROTOCOL_VERSION, execute_until_pass - +from greplin import scales +from tests.integration import BasicSharedKeyspaceUnitTestCaseWTable def setup_module(): use_singledc() @@ -33,8 +36,11 @@ def setup_module(): class MetricsTests(unittest.TestCase): def setUp(self): - self.cluster = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION) - self.session = self.cluster.connect("test3rf") + contact_point = ['127.0.0.2'] + self.cluster = Cluster(contact_points=contact_point, metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + load_balancing_policy=WhiteListRoundRobinPolicy(contact_point), + default_retry_policy=FallthroughRetryPolicy()) + self.session = self.cluster.connect("test3rf", wait_for_all_pools=True) def tearDown(self): self.cluster.shutdown() @@ -44,8 +50,6 @@ def test_connection_error(self): Trigger and ensure connection_errors are counted Stop all node with the driver knowing about the "DOWN" states. 
""" - - # Test writes for i in range(0, 100): self.session.execute_async("INSERT INTO test (k, v) VALUES ({0}, {1})".format(i, i)) @@ -145,13 +149,13 @@ def test_unavailable(self): query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query) - self.assertEqual(2, self.cluster.metrics.stats.unavailables) + self.assertEqual(self.cluster.metrics.stats.unavailables, 1) # Test write query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) with self.assertRaises(Unavailable): self.session.execute(query, timeout=None) - self.assertEqual(4, self.cluster.metrics.stats.unavailables) + self.assertEqual(self.cluster.metrics.stats.unavailables, 2) finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) # Give some time for the cluster to come back up, for the next test @@ -170,3 +174,102 @@ def test_unavailable(self): # def test_retry(self): # # TODO: Look for ways to generate retries # pass + + +class MetricsNamespaceTest(BasicSharedKeyspaceUnitTestCaseWTable): + + def test_metrics_per_cluster(self): + """ + Test to validate that metrics can be scopped to invdividual clusters + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metrics should be scopped to a cluster level + + @test_category metrics + """ + + cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION, + default_retry_policy=FallthroughRetryPolicy()) + cluster2.connect(self.ks_name, wait_for_all_pools=True) + + query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + self.session.execute(query) + + # Pause node so it shows as unreachable to coordinator + get_node(1).pause() + + try: + # Test write + query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) + with self.assertRaises(WriteTimeout): + self.session.execute(query, timeout=None) + finally: + get_node(1).resume() + + # Change the scales stats_name of the cluster2 + cluster2.metrics.set_stats_name('cluster2-metrics') + + stats_cluster1 = self.cluster.metrics.get_stats() + stats_cluster2 = cluster2.metrics.get_stats() + + # Test direct access to stats + self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) + self.assertEqual(0, cluster2.metrics.stats.write_timeouts) + + # Test direct access to a child stats + self.assertNotEqual(0.0, self.cluster.metrics.request_timer['mean']) + self.assertEqual(0.0, cluster2.metrics.request_timer['mean']) + + # Test access via metrics.get_stats() + self.assertNotEqual(0.0, stats_cluster1['request_timer']['mean']) + self.assertEqual(0.0, stats_cluster2['request_timer']['mean']) + + # Test access by stats_name + self.assertEqual(0.0, scales.getStats()['cluster2-metrics']['request_timer']['mean']) + + cluster2.shutdown() + + def test_duplicate_metrics_per_cluster(self): + """ + Test to validate that cluster metrics names can't overlap. + @since 3.6.0 + @jira_ticket PYTHON-561 + @expected_result metric names should not be allowed to be same. 
+
+        @test_category metrics
+        """
+        cluster2 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION,
+                           default_retry_policy=FallthroughRetryPolicy())
+
+        cluster3 = Cluster(metrics_enabled=True, protocol_version=PROTOCOL_VERSION,
+                           default_retry_policy=FallthroughRetryPolicy())
+
+        # Ensure duplicate metric names are not allowed
+        cluster2.metrics.set_stats_name("appcluster")
+        cluster2.metrics.set_stats_name("appcluster")
+        with self.assertRaises(ValueError):
+            cluster3.metrics.set_stats_name("appcluster")
+        cluster3.metrics.set_stats_name("devops")
+
+        session2 = cluster2.connect(self.ks_name, wait_for_all_pools=True)
+        session3 = cluster3.connect(self.ks_name, wait_for_all_pools=True)
+
+        # Basic validation that naming metrics doesn't impact their segregation or accuracy
+        for i in range(10):
+            query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL)
+            session2.execute(query)
+
+        for i in range(5):
+            query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL)
+            session3.execute(query)
+
+        self.assertEqual(cluster2.metrics.get_stats()['request_timer']['count'], 10)
+        self.assertEqual(cluster3.metrics.get_stats()['request_timer']['count'], 5)
+
+        # Check scales to ensure they are appropriately named
+        self.assertTrue("appcluster" in scales._Stats.stats.keys())
+        self.assertTrue("devops" in scales._Stats.stats.keys())
+
+
+
+
diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py
index 4bd742bc42..719f2b1fc9 100644
--- a/tests/integration/standard/test_query.py
+++ b/tests/integration/standard/test_query.py
@@ -26,7 +26,7 @@
 from cassandra.cluster import Cluster, NoHostAvailable
 from cassandra.policies import HostDistance, RoundRobinPolicy

-from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3, MockLoggingHandler
+from tests.integration import use_singledc, PROTOCOL_VERSION, BasicSharedKeyspaceUnitTestCase, get_server_versions, greaterthanprotocolv3, MockLoggingHandler, get_supported_protocol_versions

 import time
 import re
@@ -191,7 +191,7 @@ def test_incomplete_query_trace(self):
         self.assertTrue(self._wait_for_trace_to_populate(trace.trace_id))

         # Delete trace duration from the session (this is what the driver polls for "complete")
-        delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL)
+        delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL)
         self.session.execute(delete_statement)
         self.assertTrue(self._wait_for_trace_to_delete(trace.trace_id))
@@ -225,7 +225,7 @@ def _wait_for_trace_to_delete(self, trace_id):
         return count != retry_max

     def _is_trace_present(self, trace_id):
-        select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {}".format(trace_id), consistency_level=ConsistencyLevel.ALL)
+        select_statement = SimpleStatement("SElECT duration FROM system_traces.sessions WHERE session_id = {0}".format(trace_id), consistency_level=ConsistencyLevel.ALL)
         ssrs = self.session.execute(select_statement)
         if(ssrs[0].duration is None):
             return False
@@ -356,6 +356,39 @@ def make_query_plan(self, working_keyspace=None, query=None):
         return list(self._live_hosts)


+class
PreparedStatementMetadataTest(unittest.TestCase):
+
+    def test_prepared_metadata_generation(self):
+        """
+        Test to validate that result metadata is appropriately populated across protocol versions
+
+        In protocol version 1 result metadata is retrieved every time the statement is issued. In all
+        other protocol versions it's set once upon the prepare, then re-used. This test ensures that it
+        manifests itself the same across multiple protocol versions.
+
+        @since 3.6.0
+        @jira_ticket PYTHON-71
+        @expected_result result metadata is consistent.
+        """
+
+        base_line = None
+        for proto_version in get_supported_protocol_versions():
+            cluster = Cluster(protocol_version=proto_version)
+            session = cluster.connect()
+            select_statement = session.prepare("SELECT * FROM system.local")
+            if proto_version == 1:
+                self.assertEqual(select_statement.result_metadata, None)
+            else:
+                self.assertNotEqual(select_statement.result_metadata, None)
+            future = session.execute_async(select_statement)
+            results = future.result()
+            if base_line is None:
+                base_line = results[0].__dict__.keys()
+            else:
+                self.assertEqual(base_line, results[0].__dict__.keys())
+            cluster.shutdown()
+
+
 class PreparedStatementArgTest(unittest.TestCase):

     def test_prepare_on_all_hosts(self):
@@ -881,73 +914,73 @@ def test_mv_filtering(self):
         query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name),
                                           consistency_level=ConsistencyLevel.QUORUM)
         results = self.session.execute(query_statement)
-        self.assertEquals(results[0].game, 'Coup')
-        self.assertEquals(results[0].year, 2015)
-        self.assertEquals(results[0].month, 5)
-        self.assertEquals(results[0].day, 1)
-        self.assertEquals(results[0].score, 4000)
-        self.assertEquals(results[0].user, "pcmanus")
+        self.assertEqual(results[0].game, 'Coup')
+        self.assertEqual(results[0].year, 2015)
+        self.assertEqual(results[0].month, 5)
+        self.assertEqual(results[0].day, 1)
+        self.assertEqual(results[0].score, 4000)
+        self.assertEqual(results[0].user, "pcmanus")

         # Test prepared statement and daily high filtering
         prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? and day=?".format(self.keyspace_name))
         bound_query = prepared_query.bind(("Coup", 2015, 6, 2))
         results = self.session.execute(bound_query)
-        self.assertEquals(results[0].game, 'Coup')
-        self.assertEquals(results[0].year, 2015)
-        self.assertEquals(results[0].month, 6)
-        self.assertEquals(results[0].day, 2)
-        self.assertEquals(results[0].score, 2000)
-        self.assertEquals(results[0].user, "pcmanus")
-
-        self.assertEquals(results[1].game, 'Coup')
-        self.assertEquals(results[1].year, 2015)
-        self.assertEquals(results[1].month, 6)
-        self.assertEquals(results[1].day, 2)
-        self.assertEquals(results[1].score, 1000)
-        self.assertEquals(results[1].user, "tjake")
+        self.assertEqual(results[0].game, 'Coup')
+        self.assertEqual(results[0].year, 2015)
+        self.assertEqual(results[0].month, 6)
+        self.assertEqual(results[0].day, 2)
+        self.assertEqual(results[0].score, 2000)
+        self.assertEqual(results[0].user, "pcmanus")
+
+        self.assertEqual(results[1].game, 'Coup')
+        self.assertEqual(results[1].year, 2015)
+        self.assertEqual(results[1].month, 6)
+        self.assertEqual(results[1].day, 2)
+        self.assertEqual(results[1].score, 1000)
+        self.assertEqual(results[1].user, "tjake")

         # Test montly high range queries
         prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ?
and score <= ?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) results = self.session.execute(bound_query) - self.assertEquals(results[0].game, 'Coup') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 6) - self.assertEquals(results[0].day, 20) - self.assertEquals(results[0].score, 3500) - self.assertEquals(results[0].user, "jbellis") - - self.assertEquals(results[1].game, 'Coup') - self.assertEquals(results[1].year, 2015) - self.assertEquals(results[1].month, 6) - self.assertEquals(results[1].day, 9) - self.assertEquals(results[1].score, 2700) - self.assertEquals(results[1].user, "jmckenzie") - - self.assertEquals(results[2].game, 'Coup') - self.assertEquals(results[2].year, 2015) - self.assertEquals(results[2].month, 6) - self.assertEquals(results[2].day, 1) - self.assertEquals(results[2].score, 2500) - self.assertEquals(results[2].user, "iamaleksey") + self.assertEqual(results[0].game, 'Coup') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 20) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") + + self.assertEqual(results[1].game, 'Coup') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 6) + self.assertEqual(results[1].day, 9) + self.assertEqual(results[1].score, 2700) + self.assertEqual(results[1].user, "jmckenzie") + + self.assertEqual(results[2].game, 'Coup') + self.assertEqual(results[2].year, 2015) + self.assertEqual(results[2].month, 6) + self.assertEqual(results[2].day, 1) + self.assertEqual(results[2].score, 2500) + self.assertEqual(results[2].user, "iamaleksey") # Test filtered user high scores query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEquals(results[0].game, 'Chess') - self.assertEquals(results[0].year, 2015) - self.assertEquals(results[0].month, 6) - self.assertEquals(results[0].day, 21) - self.assertEquals(results[0].score, 3500) - self.assertEquals(results[0].user, "jbellis") - - self.assertEquals(results[1].game, 'Chess') - self.assertEquals(results[1].year, 2015) - self.assertEquals(results[1].month, 1) - self.assertEquals(results[1].day, 25) - self.assertEquals(results[1].score, 3200) - self.assertEquals(results[1].user, "pcmanus") + self.assertEqual(results[0].game, 'Chess') + self.assertEqual(results[0].year, 2015) + self.assertEqual(results[0].month, 6) + self.assertEqual(results[0].day, 21) + self.assertEqual(results[0].score, 3500) + self.assertEqual(results[0].user, "jbellis") + + self.assertEqual(results[1].game, 'Chess') + self.assertEqual(results[1].year, 2015) + self.assertEqual(results[1].month, 1) + self.assertEqual(results[1].day, 25) + self.assertEqual(results[1].score, 3200) + self.assertEqual(results[1].user, "pcmanus") class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index 736e7957e2..f959d4d9f9 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -31,7 +31,7 @@ from tests.unit.cython.utils import cythontest from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \ - BasicSharedKeyspaceUnitTestCase, greaterthancass20, lessthancass30 + BasicSharedKeyspaceUnitTestCase, 
greaterthancass21, lessthancass30 from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, \ get_sample, get_collection_sample @@ -796,7 +796,7 @@ def test_cython_decimal(self): class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): - @greaterthancass20 + @greaterthancass21 @lessthancass30 def test_nested_types_with_protocol_version(self): """ diff --git a/tests/integration/standard/utils.py b/tests/integration/standard/utils.py index 4011047fc8..917b3a7f6e 100644 --- a/tests/integration/standard/utils.py +++ b/tests/integration/standard/utils.py @@ -4,6 +4,7 @@ from tests.integration.datatype_utils import PRIMITIVE_DATATYPES, get_sample + def create_table_with_all_types(table_name, session, N): """ Method that given a table_name and session construct a table that contains diff --git a/tests/unit/cqlengine/__init__.py b/tests/unit/cqlengine/__init__.py new file mode 100644 index 0000000000..87fc3685e0 --- /dev/null +++ b/tests/unit/cqlengine/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2013-2016 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/cqlengine/test_columns.py b/tests/unit/cqlengine/test_columns.py new file mode 100644 index 0000000000..181c103515 --- /dev/null +++ b/tests/unit/cqlengine/test_columns.py @@ -0,0 +1,68 @@ +# Copyright 2013-2016 DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +try: + import unittest2 as unittest +except ImportError: + import unittest # noqa + +from cassandra.cqlengine.columns import Column + + +class ColumnTest(unittest.TestCase): + + def test_comparisons(self): + c0 = Column() + c1 = Column() + self.assertEqual(c1.position - c0.position, 1) + + # __ne__ + self.assertNotEqual(c0, c1) + self.assertNotEqual(c0, object()) + + # __eq__ + self.assertEqual(c0, c0) + self.assertFalse(c0 == object()) + + # __lt__ + self.assertLess(c0, c1) + try: + c0 < object() # this raises for Python 3 + except TypeError: + pass + + # __le__ + self.assertLessEqual(c0, c1) + self.assertLessEqual(c0, c0) + try: + c0 <= object() # this raises for Python 3 + except TypeError: + pass + + # __gt__ + self.assertGreater(c1, c0) + try: + c1 > object() # this raises for Python 3 + except TypeError: + pass + + # __ge__ + self.assertGreaterEqual(c1, c0) + self.assertGreaterEqual(c1, c1) + try: + c1 >= object() # this raises for Python 3 + except TypeError: + pass + + diff --git a/tests/unit/test_concurrent.py b/tests/unit/test_concurrent.py index a535bf2260..948f6f2502 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -17,12 +17,15 @@ import unittest2 as unittest except ImportError: import unittest # noqa + from itertools import cycle from mock import Mock import time import threading from six.moves.queue import PriorityQueue +import sys +from cassandra.cluster import Cluster, Session from cassandra.concurrent import execute_concurrent, execute_concurrent_with_args @@ -229,3 +232,19 @@ def validate_result_ordering(self, results): current_time_added = list(result)[0] self.assertLess(last_time_added, current_time_added) last_time_added = current_time_added + + def test_recursion_limited(self): + """ + Verify that recursion is controlled when raise_on_first_error=False and something is wrong with the query. 
+ + PYTHON-585 + """ + max_recursion = sys.getrecursionlimit() + s = Session(Cluster(), []) + self.assertRaises(TypeError, execute_concurrent_with_args, s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) + + results = execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=False) # previously + self.assertEqual(len(results), max_recursion) + for r in results: + self.assertFalse(r[0]) + self.assertIsInstance(r[1], TypeError) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2ac10a590f..b8cb640b46 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -112,7 +112,7 @@ def test_negative_body_length(self, *args): def test_unsupported_cql_version(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() c.cql_version = "3.0.3" @@ -135,7 +135,7 @@ def test_unsupported_cql_version(self, *args): def test_prefer_lz4_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() c.cql_version = "3.0.3" @@ -158,7 +158,7 @@ def test_prefer_lz4_compression(self, *args): def test_requested_compression_not_available(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() # request lz4 compression c.compression = "lz4" @@ -188,7 +188,7 @@ def test_requested_compression_not_available(self, *args): def test_use_requested_compression(self, *args): c = self.make_connection() - c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message)} + c._requests = {0: (c._handle_options_response, ProtocolHandler.decode_message, [])} c.defunct = Mock() # request snappy compression c.compression = "snappy" diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index fb0ca21711..5fe230f402 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -165,7 +165,8 @@ def test_spawn_when_at_max(self): def test_return_defunct_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -177,14 +178,14 @@ def test_return_defunct_connection(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - conn.close.assert_called_once() - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) self.assertFalse(pool.is_shutdown) def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False, max_request_id=100, signaled_error=False) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, 
is_defunct=False, is_closed=False, + max_request_id=100, signaled_error=False) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -196,15 +197,15 @@ def test_return_defunct_connection_on_down_host(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - session.cluster.signal_connection_failure.assert_called_once() - conn.close.assert_called_once() + self.assertTrue(session.cluster.signal_connection_failure.call_args) + self.assertTrue(conn.close.call_args) self.assertFalse(session.submit.called) self.assertTrue(pool.is_shutdown) def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') session = self.make_session() - conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100) + conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100, signaled_error=False) session.cluster.connection_factory.return_value = conn pool = HostConnectionPool(host, HostDistance.LOCAL, session) @@ -216,7 +217,7 @@ def test_return_closed_connection(self): pool.return_connection(conn) # a new creation should be scheduled - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) self.assertFalse(pool.is_shutdown) def test_host_instantiations(self): diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index d48b5d9573..555dfe3834 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -91,7 +91,7 @@ def setUpClass(cls): routing_key_indexes=[1, 0], query=None, keyspace='keyspace', - protocol_version=cls.protocol_version) + protocol_version=cls.protocol_version, result_metadata=None) cls.bound = BoundStatement(prepared_statement=cls.prepared) def test_invalid_argument_type(self): @@ -130,7 +130,8 @@ def test_inherit_fetch_size(self): routing_key_indexes=[], query=None, keyspace=keyspace, - protocol_version=self.protocol_version) + protocol_version=self.protocol_version, + result_metadata=None) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) self.assertEqual(1234, bound_statement.fetch_size) @@ -163,7 +164,8 @@ def test_values_none(self): routing_key_indexes=[], query=None, keyspace='whatever', - protocol_version=self.protocol_version) + protocol_version=self.protocol_version, + result_metadata=None) bound = prepared_statement.bind(None) self.assertListEqual(bound.values, []) diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index ad5bb3e93b..88b08af878 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -67,7 +67,7 @@ def test_result_message(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) rf._set_result(self.make_mock_response([{'col': 'val'}])) result = rf.result() @@ -192,7 +192,7 @@ def test_retry_policy_says_retry(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - 
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(result) @@ -210,7 +210,7 @@ def test_retry_policy_says_retry(self): # an UnavailableException rf.session._pools.get.assert_called_with('ip1') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): session = self.make_session() @@ -225,7 +225,7 @@ def test_retry_with_different_host(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY) - connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) result = Mock(spec=OverloadedErrorMessage, info={}) @@ -243,7 +243,7 @@ def test_retry_with_different_host(self): # it should try with a different host rf.session._pools.get.assert_called_with('ip2') pool.borrow_connection.assert_called_with(timeout=ANY) - connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message) + connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) # the consistency level should be the same self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) @@ -480,7 +480,7 @@ def test_prepared_query_not_found(self): result = Mock(spec=PreparedQueryNotFound, info='a' * 16) rf._set_result(result) - session.submit.assert_called_once() + self.assertTrue(session.submit.call_args) args, kwargs = session.submit.call_args self.assertEqual(rf._reprepare, args[-2]) self.assertIsInstance(args[-1], PrepareMessage) From 156fcbfd393578638eccf70951d85f34cca97695 Mon Sep 17 00:00:00 2001 From: Stefania Alborghetti Date: Tue, 4 Oct 2016 15:57:34 +0800 Subject: [PATCH 03/10] 12736: self.prepared_statement may be None for a batch of prepared statements --- cassandra/cluster.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 17ff0f32c9..26446e384f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3667,7 +3667,8 @@ def _execute_after_prepare(self, host, connection, pool, response): if response.kind == RESULT_KIND_PREPARED: # result metadata is the only thing that could have changed from an alter _, _, _, result_metadata = response.results - self.prepared_statement.result_metadata = result_metadata + if self.prepared_statement: + self.prepared_statement.result_metadata = result_metadata # use self._query to re-use the same host and # at the same time properly borrow the connection From 
b27c27a97c742f6f7633faff011d3789138eef1b Mon Sep 17 00:00:00 2001 From: Alan Boudreault Date: Thu, 15 Dec 2016 18:08:15 -0500 Subject: [PATCH 04/10] reverted PR 680 until the patch is merged in cassandra --- cassandra/protocol.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 32e192ee6d..d50e4d8868 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -30,7 +30,7 @@ UserAggregateDescriptor, SchemaTargetType) from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, int8_pack, int8_unpack, uint64_pack, header_pack, - v3_header_pack, uint32_pack) + v3_header_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, @@ -561,7 +561,7 @@ def send_body(self, f, protocol_version): flags |= _PROTOCOL_TIMESTAMP if protocol_version >= 5: - write_uint(f, flags) + write_int(f, flags) else: write_byte(f, flags) @@ -775,9 +775,6 @@ def __init__(self, query): def send_body(self, f, protocol_version): write_longstring(f, self.query) - if protocol_version >= 5: - # Write the flags byte; with 0 value for now, but this should change in PYTHON-678 - write_uint(f, 0) class ExecuteMessage(_MessageType): @@ -832,7 +829,7 @@ def send_body(self, f, protocol_version): flags |= _SKIP_METADATA_FLAG if protocol_version >= 5: - write_uint(f, flags) + write_int(f, flags) else: write_byte(f, flags) @@ -1167,10 +1164,6 @@ def write_int(f, i): f.write(int32_pack(i)) -def write_uint(f, i): - f.write(uint32_pack(i)) - - def write_long(f, i): f.write(uint64_pack(i)) From 5a6bad98ed8ff09471a8cfb2db735dae2aaacf9c Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Mon, 19 Dec 2016 16:16:12 -0600 Subject: [PATCH 05/10] Revert "reverted PR 680 until the patch is merged in cassandra" This reverts commit b27c27a97c742f6f7633faff011d3789138eef1b. 
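The write_int/write_uint churn in patches 04 through 06 comes down to field width and signedness: protocol v5 widens the QUERY/EXECUTE flags field from one byte to four, and the bitmask must be packed as an unsigned 32-bit value, since a signed pack rejects any flag occupying the sign bit. A minimal standalone sketch of the distinction, using only the standard struct module; the int32_pack/uint32_pack names here merely mirror those imported from cassandra.marshal and carry no driver dependency:

    import struct

    int32_pack = struct.Struct('>i').pack    # signed pack, what write_int uses
    uint32_pack = struct.Struct('>I').pack   # unsigned pack, what write_uint uses

    flags = 0x80000000  # hypothetical flag value occupying the sign bit

    print(uint32_pack(flags))  # b'\x80\x00\x00\x00'
    try:
        int32_pack(flags)      # signed 32-bit pack rejects values >= 2**31
    except struct.error as exc:
        print('int32 pack failed:', exc)
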
--- cassandra/protocol.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d50e4d8868..32e192ee6d 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -30,7 +30,7 @@ UserAggregateDescriptor, SchemaTargetType) from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, int8_pack, int8_unpack, uint64_pack, header_pack, - v3_header_pack) + v3_header_pack, uint32_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, @@ -561,7 +561,7 @@ def send_body(self, f, protocol_version): flags |= _PROTOCOL_TIMESTAMP if protocol_version >= 5: - write_int(f, flags) + write_uint(f, flags) else: write_byte(f, flags) @@ -775,6 +775,9 @@ def __init__(self, query): def send_body(self, f, protocol_version): write_longstring(f, self.query) + if protocol_version >= 5: + # Write the flags byte; with 0 value for now, but this should change in PYTHON-678 + write_uint(f, 0) class ExecuteMessage(_MessageType): @@ -829,7 +832,7 @@ def send_body(self, f, protocol_version): flags |= _SKIP_METADATA_FLAG if protocol_version >= 5: - write_int(f, flags) + write_uint(f, flags) else: write_byte(f, flags) @@ -1164,6 +1167,10 @@ def write_int(f, i): f.write(int32_pack(i)) +def write_uint(f, i): + f.write(uint32_pack(i)) + + def write_long(f, i): f.write(uint64_pack(i)) From bab71d32b6ad69f126b0894a6c3537fe8c7de045 Mon Sep 17 00:00:00 2001 From: Tyler Hobbs Date: Mon, 19 Dec 2016 16:44:40 -0600 Subject: [PATCH 06/10] Revert "Revert "reverted PR 680 until the patch is merged in cassandra"" This reverts commit 5a6bad98ed8ff09471a8cfb2db735dae2aaacf9c. --- cassandra/protocol.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index 32e192ee6d..d50e4d8868 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -30,7 +30,7 @@ UserAggregateDescriptor, SchemaTargetType) from cassandra.marshal import (int32_pack, int32_unpack, uint16_pack, uint16_unpack, int8_pack, int8_unpack, uint64_pack, header_pack, - v3_header_pack, uint32_pack) + v3_header_pack) from cassandra.cqltypes import (AsciiType, BytesType, BooleanType, CounterColumnType, DateType, DecimalType, DoubleType, FloatType, Int32Type, @@ -561,7 +561,7 @@ def send_body(self, f, protocol_version): flags |= _PROTOCOL_TIMESTAMP if protocol_version >= 5: - write_uint(f, flags) + write_int(f, flags) else: write_byte(f, flags) @@ -775,9 +775,6 @@ def __init__(self, query): def send_body(self, f, protocol_version): write_longstring(f, self.query) - if protocol_version >= 5: - # Write the flags byte; with 0 value for now, but this should change in PYTHON-678 - write_uint(f, 0) class ExecuteMessage(_MessageType): @@ -832,7 +829,7 @@ def send_body(self, f, protocol_version): flags |= _SKIP_METADATA_FLAG if protocol_version >= 5: - write_uint(f, flags) + write_int(f, flags) else: write_byte(f, flags) @@ -1167,10 +1164,6 @@ def write_int(f, i): f.write(int32_pack(i)) -def write_uint(f, i): - f.write(uint32_pack(i)) - - def write_long(f, i): f.write(uint64_pack(i)) From b2d975841ad6e7413b10335432e7e8fec6cd2a8b Mon Sep 17 00:00:00 2001 From: Jim Witschey Date: Thu, 9 Mar 2017 12:43:27 -0500 Subject: [PATCH 07/10] finish merge from integration-tested branch --- cassandra/cluster.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 
2d39d18400..26c36d9169 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3687,8 +3687,7 @@ def _execute_after_prepare(self, host, connection, pool, response): if response.kind == RESULT_KIND_PREPARED: # result metadata is the only thing that could have changed from an alter _, _, _, result_metadata = response.results - if self.prepared_statement: - self.prepared_statement.result_metadata = result_metadata + self.prepared_statement.result_metadata = result_metadata # use self._query to re-use the same host and # at the same time properly borrow the connection From 7c5064409924972d783233ba181e57ab6f31beab Mon Sep 17 00:00:00 2001 From: Jim Witschey Date: Mon, 24 Jul 2017 16:52:29 -0400 Subject: [PATCH 08/10] update version post-3.11.0-release --- cassandra/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cassandra/__init__.py b/cassandra/__init__.py index d7f3aeca44..feabc14192 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -22,7 +22,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 11, 0) +__version_info__ = (3, 11, 0, 'post0') __version__ = '.'.join(map(str, __version_info__)) From 12c5035a35852d43373dc2ade817974865c38a45 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 23 Feb 2021 10:07:53 -0600 Subject: [PATCH 09/10] Revert "Merge branch 'master' into cassandra-test" This reverts commit e62a2c8f6138d7341df54e56a49a67b0b0e8c1de, reversing changes made to 2508df96053a0f5caf0f8c10e549538084610970. --- cassandra/__init__.py | 11 ++---- cassandra/cluster.py | 14 ++------ cassandra/connection.py | 11 ++---- cassandra/protocol.py | 4 --- tests/integration/__init__.py | 9 +++-- .../simulacron/test_empty_column.py | 4 +-- tests/integration/standard/test_cluster.py | 36 +++++++------------ tests/unit/test_cluster.py | 2 -- 8 files changed, 26 insertions(+), 65 deletions(-) diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 1e16bca287..100df2df17 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -161,12 +161,7 @@ class ProtocolVersion(object): V5 = 5 """ - v5, in beta from 3.x+. 
Finalised in 4.0-beta5 - """ - - V6 = 6 - """ - v6, in beta from 4.0-beta5 + v5, in beta from 3.x+ """ DSE_V1 = 0x41 @@ -179,12 +174,12 @@ class ProtocolVersion(object): DSE private protocol v2, supported in DSE 6.0+ """ - SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1) + SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V5, V4, V3, V2, V1) """ A tuple of all supported protocol versions """ - BETA_VERSIONS = (V6,) + BETA_VERSIONS = (V5,) """ A tuple of all beta protocol versions """ diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 7e101afba8..cedcf8207b 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -63,7 +63,7 @@ BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler, - RESULT_KIND_VOID, ProtocolException) + RESULT_KIND_VOID) from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, @@ -1570,7 +1570,7 @@ def set_core_connections_per_host(self, host_distance, core_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupportedOperation`. + and using this will result in an :exc:`~.UnsupporteOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -1603,7 +1603,7 @@ def set_max_connections_per_host(self, host_distance, max_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupportedOperation`. + and using this will result in an :exc:`~.UnsupporteOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -3548,14 +3548,6 @@ def _try_connect(self, host): break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.endpoint, e.startup_version) - except ProtocolException as e: - # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver - # protocol version. If the protocol version was not explicitly specified, - # and that the server raises a beta protocol error, we should downgrade. 
- if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: - self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) - else: - raise log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", diff --git a/cassandra/connection.py b/cassandra/connection.py index 48b3caefed..477eaf2f28 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -897,10 +897,6 @@ def _connect_socket(self): for args in self.sockopts: self._socket.setsockopt(*args) - def _enable_compression(self): - if self._compressor: - self.compressor = self._compressor - def _enable_checksumming(self): self._io_buffer.set_checksumming_buffer() self._is_checksumming_enabled = True @@ -1332,7 +1328,8 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.authenticator.__class__.__name__) log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) - self._enable_compression() + if self._compressor: + self.compressor = self._compressor if ProtocolVersion.has_checksumming_support(self.protocol_version): self._enable_checksumming() @@ -1348,10 +1345,6 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): "if DSE authentication is configured with transitional mode" % (self.host,)) raise AuthenticationFailed('Remote end requires authentication') - self._enable_compression() - if ProtocolVersion.has_checksumming_support(self.protocol_version): - self._enable_checksumming() - if isinstance(self.authenticator, dict): log.debug("Sending credentials-based auth response on %s", self) cm = CredentialsMessage(creds=self.authenticator) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index ed92a76679..c454824637 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -180,10 +180,6 @@ class ProtocolException(ErrorMessageSub): summary = 'Protocol error' error_code = 0x000A - @property - def is_beta_protocol_error(self): - return 'USE_BETA flag is unset' in str(self) - class BadCredentials(ErrorMessageSub): summary = 'Bad credentials' diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 9d350af707..1e1f582804 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -207,6 +207,8 @@ def get_default_protocol(): if DSE_VERSION: return ProtocolVersion.DSE_V2 else: + global ALLOW_BETA_PROTOCOL + ALLOW_BETA_PROTOCOL = True return ProtocolVersion.V5 if CASSANDRA_VERSION >= Version('3.10'): if DSE_VERSION: @@ -232,12 +234,9 @@ def get_supported_protocol_versions(): 3.X -> 4, 3 3.10(C*) -> 5(beta),4,3 3.10(DSE) -> DSE_V1,4,3 - 4.0(C*) -> 6(beta),5,4,3 + 4.0(C*) -> 5(beta),4,3 4.0(DSE) -> DSE_v2, DSE_V1,4,3 ` """ - if CASSANDRA_VERSION >= Version('4.0-beta5'): - if not DSE_VERSION: - return (3, 4, 5, 6) if CASSANDRA_VERSION >= Version('4.0-a'): if DSE_VERSION: return (3, 4, ProtocolVersion.DSE_V1, ProtocolVersion.DSE_V2) @@ -317,7 +316,7 @@ def _id_and_mark(f): notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater not supported') greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') -protocolv6 = unittest.skipUnless(6 in get_supported_protocol_versions(), 'Protocol versions less than 6 are not supported') +protocolv5 = unittest.skipUnless(5 in get_supported_protocol_versions(), 'Protocol 
versions less than 5 are not supported') greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required') greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required') greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.0'), 'Cassandra version 3.0 or greater required') diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py index 91c76985e1..bd7fe6ead0 100644 --- a/tests/integration/simulacron/test_empty_column.py +++ b/tests/integration/simulacron/test_empty_column.py @@ -27,8 +27,8 @@ from cassandra.cqlengine.connection import set_session from cassandra.cqlengine.models import Model -from tests.integration import requiressimulacron -from tests.integration.simulacron import PROTOCOL_VERSION, SimulacronCluster +from tests.integration import PROTOCOL_VERSION, requiressimulacron +from tests.integration.simulacron import SimulacronCluster from tests.integration.simulacron.utils import PrimeQuery, prime_request diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index c7d8266fd9..cdb6f1f3b7 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -42,7 +42,7 @@ from tests import notwindows from tests.integration import use_singledc, get_server_versions, CASSANDRA_VERSION, \ execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ - get_unsupported_upper_protocol, protocolv6, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40, \ + get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40, \ DSE_VERSION, TestCluster, PROTOCOL_VERSION from tests.integration.util import assert_quiescent_pool_state import sys @@ -261,18 +261,6 @@ def test_protocol_negotiation(self): elif DSE_VERSION and DSE_VERSION >= Version("5.1"): self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.DSE_V1) self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.DSE_V1) - elif CASSANDRA_VERSION >= Version('4.0-beta5'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V5) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V5) - elif CASSANDRA_VERSION >= Version('4.0-a'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) - elif CASSANDRA_VERSION >= Version('3.11'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) - elif CASSANDRA_VERSION >= Version('3.0'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) elif CASSANDRA_VERSION >= Version('2.2'): self.assertEqual(updated_protocol_version, 4) self.assertEqual(updated_cluster_version, 4) @@ -1485,42 +1473,42 @@ def test_prepare_on_ignored_hosts(self): cluster.shutdown() -@protocolv6 +@protocolv5 class BetaProtocolTest(unittest.TestCase): - @protocolv6 + @protocolv5 def test_invalid_protocol_version_beta_option(self): """ - Test cluster connection with protocol v6 and beta flag not set + Test cluster connection with protocol v5 and beta flag not set @since 3.7.0 - @jira_ticket PYTHON-614, 
PYTHON-1232 - @expected_result client shouldn't connect with V6 and no beta flag set + @jira_ticket PYTHON-614 + @expected_result client shouldn't connect with V5 and no beta flag set @test_category connection """ - cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=False) + cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V5, allow_beta_protocol_version=False) try: with self.assertRaises(NoHostAvailable): cluster.connect() except Exception as e: self.fail("Unexpected error encountered {0}".format(e.message)) - @protocolv6 + @protocolv5 def test_valid_protocol_version_beta_options_connect(self): """ Test cluster connection with protocol version 5 and beta flag set @since 3.7.0 - @jira_ticket PYTHON-614, PYTHON-1232 - @expected_result client should connect with protocol v6 and beta flag set. + @jira_ticket PYTHON-614 + @expected_result client should connect with protocol v5 and beta flag set. @test_category connection """ - cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=True) + cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V5, allow_beta_protocol_version=True) session = cluster.connect() - self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V6) + self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V5) self.assertTrue(session.execute("select release_version from system.local")[0]) cluster.shutdown() diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 620f642084..249c0a17cc 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -209,8 +209,6 @@ def test_protocol_downgrade_test(self): lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V2) self.assertEqual(ProtocolVersion.DSE_V1, lower) lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V1) - self.assertEqual(ProtocolVersion.V5,lower) - lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V5) self.assertEqual(ProtocolVersion.V4,lower) lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V4) self.assertEqual(ProtocolVersion.V3,lower) From db07a4dec9baa4eed46f286628c2287d3701afd9 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Tue, 23 Feb 2021 10:52:23 -0600 Subject: [PATCH 10/10] Revert "Revert "Merge branch 'master' into cassandra-test"" This reverts commit 12c5035a35852d43373dc2ade817974865c38a45. --- cassandra/__init__.py | 11 ++++-- cassandra/cluster.py | 14 ++++++-- cassandra/connection.py | 11 ++++-- cassandra/protocol.py | 4 +++ tests/integration/__init__.py | 9 ++--- .../simulacron/test_empty_column.py | 4 +-- tests/integration/standard/test_cluster.py | 36 ++++++++++++------- tests/unit/test_cluster.py | 2 ++ 8 files changed, 65 insertions(+), 26 deletions(-) diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 100df2df17..1e16bca287 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -161,7 +161,12 @@ class ProtocolVersion(object): V5 = 5 """ - v5, in beta from 3.x+ + v5, in beta from 3.x+. 
Finalised in 4.0-beta5 + """ + + V6 = 6 + """ + v6, in beta from 4.0-beta5 """ DSE_V1 = 0x41 @@ -174,12 +179,12 @@ class ProtocolVersion(object): DSE private protocol v2, supported in DSE 6.0+ """ - SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V5, V4, V3, V2, V1) + SUPPORTED_VERSIONS = (DSE_V2, DSE_V1, V6, V5, V4, V3, V2, V1) """ A tuple of all supported protocol versions """ - BETA_VERSIONS = (V5,) + BETA_VERSIONS = (V6,) """ A tuple of all beta protocol versions """ diff --git a/cassandra/cluster.py b/cassandra/cluster.py index cedcf8207b..7e101afba8 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -63,7 +63,7 @@ BatchMessage, RESULT_KIND_PREPARED, RESULT_KIND_SET_KEYSPACE, RESULT_KIND_ROWS, RESULT_KIND_SCHEMA_CHANGE, ProtocolHandler, - RESULT_KIND_VOID) + RESULT_KIND_VOID, ProtocolException) from cassandra.metadata import Metadata, protect_name, murmur3, _NodeInfo from cassandra.policies import (TokenAwarePolicy, DCAwareRoundRobinPolicy, SimpleConvictionPolicy, ExponentialReconnectionPolicy, HostDistance, @@ -1570,7 +1570,7 @@ def set_core_connections_per_host(self, host_distance, core_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupporteOperation`. + and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -1603,7 +1603,7 @@ def set_max_connections_per_host(self, host_distance, max_connections): If :attr:`~.Cluster.protocol_version` is set to 3 or higher, this is not supported (there is always one connection per host, unless the host is remote and :attr:`connect_to_remote_hosts` is :const:`False`) - and using this will result in an :exc:`~.UnsupporteOperation`. + and using this will result in an :exc:`~.UnsupportedOperation`. """ if self.protocol_version >= 3: raise UnsupportedOperation( @@ -3548,6 +3548,14 @@ def _try_connect(self, host): break except ProtocolVersionUnsupported as e: self._cluster.protocol_downgrade(host.endpoint, e.startup_version) + except ProtocolException as e: + # protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver + # protocol version. If the protocol version was not explicitly specified, + # and that the server raises a beta protocol error, we should downgrade. 
+ if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error: + self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version) + else: + raise log.debug("[control connection] Established new connection %r, " "registering watchers and refreshing schema and topology", diff --git a/cassandra/connection.py b/cassandra/connection.py index 477eaf2f28..48b3caefed 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -897,6 +897,10 @@ def _connect_socket(self): for args in self.sockopts: self._socket.setsockopt(*args) + def _enable_compression(self): + if self._compressor: + self.compressor = self._compressor + def _enable_checksumming(self): self._io_buffer.set_checksumming_buffer() self._is_checksumming_enabled = True @@ -1328,8 +1332,7 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): self.authenticator.__class__.__name__) log.debug("Got ReadyMessage on new connection (%s) from %s", id(self), self.endpoint) - if self._compressor: - self.compressor = self._compressor + self._enable_compression() if ProtocolVersion.has_checksumming_support(self.protocol_version): self._enable_checksumming() @@ -1345,6 +1348,10 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): "if DSE authentication is configured with transitional mode" % (self.host,)) raise AuthenticationFailed('Remote end requires authentication') + self._enable_compression() + if ProtocolVersion.has_checksumming_support(self.protocol_version): + self._enable_checksumming() + if isinstance(self.authenticator, dict): log.debug("Sending credentials-based auth response on %s", self) cm = CredentialsMessage(creds=self.authenticator) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index c454824637..ed92a76679 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -180,6 +180,10 @@ class ProtocolException(ErrorMessageSub): summary = 'Protocol error' error_code = 0x000A + @property + def is_beta_protocol_error(self): + return 'USE_BETA flag is unset' in str(self) + class BadCredentials(ErrorMessageSub): summary = 'Bad credentials' diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 1e1f582804..9d350af707 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -207,8 +207,6 @@ def get_default_protocol(): if DSE_VERSION: return ProtocolVersion.DSE_V2 else: - global ALLOW_BETA_PROTOCOL - ALLOW_BETA_PROTOCOL = True return ProtocolVersion.V5 if CASSANDRA_VERSION >= Version('3.10'): if DSE_VERSION: @@ -234,9 +232,12 @@ def get_supported_protocol_versions(): 3.X -> 4, 3 3.10(C*) -> 5(beta),4,3 3.10(DSE) -> DSE_V1,4,3 - 4.0(C*) -> 5(beta),4,3 + 4.0(C*) -> 6(beta),5,4,3 4.0(DSE) -> DSE_v2, DSE_V1,4,3 ` """ + if CASSANDRA_VERSION >= Version('4.0-beta5'): + if not DSE_VERSION: + return (3, 4, 5, 6) if CASSANDRA_VERSION >= Version('4.0-a'): if DSE_VERSION: return (3, 4, ProtocolVersion.DSE_V1, ProtocolVersion.DSE_V2) @@ -316,7 +317,7 @@ def _id_and_mark(f): notprotocolv1 = unittest.skipUnless(PROTOCOL_VERSION > 1, 'Protocol v1 not supported') lessthenprotocolv4 = unittest.skipUnless(PROTOCOL_VERSION < 4, 'Protocol versions 4 or greater not supported') greaterthanprotocolv3 = unittest.skipUnless(PROTOCOL_VERSION >= 4, 'Protocol versions less than 4 are not supported') -protocolv5 = unittest.skipUnless(5 in get_supported_protocol_versions(), 'Protocol versions less than 5 are not supported') +protocolv6 = unittest.skipUnless(6 in get_supported_protocol_versions(), 'Protocol 
versions less than 6 are not supported') greaterthancass20 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.1'), 'Cassandra version 2.1 or greater required') greaterthancass21 = unittest.skipUnless(CASSANDRA_VERSION >= Version('2.2'), 'Cassandra version 2.2 or greater required') greaterthanorequalcass30 = unittest.skipUnless(CASSANDRA_VERSION >= Version('3.0'), 'Cassandra version 3.0 or greater required') diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py index bd7fe6ead0..91c76985e1 100644 --- a/tests/integration/simulacron/test_empty_column.py +++ b/tests/integration/simulacron/test_empty_column.py @@ -27,8 +27,8 @@ from cassandra.cqlengine.connection import set_session from cassandra.cqlengine.models import Model -from tests.integration import PROTOCOL_VERSION, requiressimulacron -from tests.integration.simulacron import SimulacronCluster +from tests.integration import requiressimulacron +from tests.integration.simulacron import PROTOCOL_VERSION, SimulacronCluster from tests.integration.simulacron.utils import PrimeQuery, prime_request diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index cdb6f1f3b7..c7d8266fd9 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -42,7 +42,7 @@ from tests import notwindows from tests.integration import use_singledc, get_server_versions, CASSANDRA_VERSION, \ execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \ - get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40, \ + get_unsupported_upper_protocol, protocolv6, local, CASSANDRA_IP, greaterthanorequalcass30, lessthanorequalcass40, \ DSE_VERSION, TestCluster, PROTOCOL_VERSION from tests.integration.util import assert_quiescent_pool_state import sys @@ -261,6 +261,18 @@ def test_protocol_negotiation(self): elif DSE_VERSION and DSE_VERSION >= Version("5.1"): self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.DSE_V1) self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.DSE_V1) + elif CASSANDRA_VERSION >= Version('4.0-beta5'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V5) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V5) + elif CASSANDRA_VERSION >= Version('4.0-a'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + elif CASSANDRA_VERSION >= Version('3.11'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + elif CASSANDRA_VERSION >= Version('3.0'): + self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) + self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) elif CASSANDRA_VERSION >= Version('2.2'): self.assertEqual(updated_protocol_version, 4) self.assertEqual(updated_cluster_version, 4) @@ -1473,42 +1485,42 @@ def test_prepare_on_ignored_hosts(self): cluster.shutdown() -@protocolv5 +@protocolv6 class BetaProtocolTest(unittest.TestCase): - @protocolv5 + @protocolv6 def test_invalid_protocol_version_beta_option(self): """ - Test cluster connection with protocol v5 and beta flag not set + Test cluster connection with protocol v6 and beta flag not set @since 3.7.0 - @jira_ticket PYTHON-614 - 
@expected_result client shouldn't connect with V5 and no beta flag set + @jira_ticket PYTHON-614, PYTHON-1232 + @expected_result client shouldn't connect with V6 and no beta flag set @test_category connection """ - cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V5, allow_beta_protocol_version=False) + cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=False) try: with self.assertRaises(NoHostAvailable): cluster.connect() except Exception as e: self.fail("Unexpected error encountered {0}".format(e.message)) - @protocolv5 + @protocolv6 def test_valid_protocol_version_beta_options_connect(self): """ Test cluster connection with protocol version 5 and beta flag set @since 3.7.0 - @jira_ticket PYTHON-614 - @expected_result client should connect with protocol v5 and beta flag set. + @jira_ticket PYTHON-614, PYTHON-1232 + @expected_result client should connect with protocol v6 and beta flag set. @test_category connection """ - cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V5, allow_beta_protocol_version=True) + cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=True) session = cluster.connect() - self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V5) + self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V6) self.assertTrue(session.execute("select release_version from system.local")[0]) cluster.shutdown() diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 249c0a17cc..620f642084 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -209,6 +209,8 @@ def test_protocol_downgrade_test(self): lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V2) self.assertEqual(ProtocolVersion.DSE_V1, lower) lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V1) + self.assertEqual(ProtocolVersion.V5,lower) + lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V5) self.assertEqual(ProtocolVersion.V4,lower) lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V4) self.assertEqual(ProtocolVersion.V3,lower)
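Taken together, patches 09 and 10 net out to V6 being the sole beta protocol version and V5 joining the regular downgrade chain between DSE_V1 and V4. A rough sketch of the fallback walk those restored test_cluster.py assertions describe, assuming a driver build with this patch series applied:

    from cassandra import ProtocolVersion

    # Walk the downgrade chain from the DSE private protocol down to v3,
    # mirroring the assertions in test_protocol_downgrade_test. Beta
    # versions (V6 here) are skipped by get_lower_supported.
    version = ProtocolVersion.DSE_V2
    while version > ProtocolVersion.V3:
        lower = ProtocolVersion.get_lower_supported(version)
        print('{0} -> {1}'.format(version, lower))
        version = lower

Expected chain: DSE_V2 -> DSE_V1 -> V5 -> V4 -> V3, matching the unit test above.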