Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 96df657

Browse filesBrowse files
committed
Add 'BatchTransaction' wrapper class (#438)
Encapsulates session ID / transaction ID, to be marshalled across the wire to another process / host for performing partitioned reads / queries.
1 parent 6b62951 commit 96df657
Copy full SHA for 96df657

File tree

Expand file treeCollapse file tree

5 files changed

+1064
-30
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+1064
-30
lines changed

‎spanner/google/cloud/spanner_v1/database.py

Copy file name to clipboardExpand all lines: spanner/google/cloud/spanner_v1/database.py
+266Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
2828
from google.cloud.spanner_v1.batch import Batch
2929
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
30+
from google.cloud.spanner_v1.keyset import KeySet
3031
from google.cloud.spanner_v1.pool import BurstyPool
3132
from google.cloud.spanner_v1.pool import SessionCheckout
3233
from google.cloud.spanner_v1.session import Session
@@ -308,6 +309,14 @@ def batch(self):
308309
"""
309310
return BatchCheckout(self)
310311

312+
def batch_transaction(self):
    """Return an object which wraps a batch read / query.

    The wrapper lazily creates its own session / transaction on first
    use; the caller is responsible for invoking its ``close`` method to
    reclaim the session once all partitions have been processed.

    :rtype: :class:`~google.cloud.spanner_v1.database.BatchTransaction`
    :returns: new wrapper
    """
    return BatchTransaction(self)
319+
311320
def run_in_transaction(self, func, *args, **kw):
312321
"""Perform a unit of work in a transaction, retrying on abort.
313322
@@ -406,6 +415,263 @@ def __exit__(self, exc_type, exc_val, exc_tb):
406415
self._database._pool.put(self._session)
407416

408417

418+
class BatchTransaction(object):
    """Wrapper for generating and processing read / query batches.

    :type database: :class:`~google.cloud.spanner.database.Database`
    :param database: database to use

    :type read_timestamp: :class:`datetime.datetime`
    :param read_timestamp: Execute all reads at the given timestamp.

    :type min_read_timestamp: :class:`datetime.datetime`
    :param min_read_timestamp: Execute all reads at a
        timestamp >= ``min_read_timestamp``.

    :type max_staleness: :class:`datetime.timedelta`
    :param max_staleness: Read data at a
        timestamp >= NOW - ``max_staleness`` seconds.

    :type exact_staleness: :class:`datetime.timedelta`
    :param exact_staleness: Execute all reads at a timestamp that is
        ``exact_staleness`` old.
    """
    def __init__(
            self, database,
            read_timestamp=None,
            min_read_timestamp=None,
            max_staleness=None,
            exact_staleness=None):

        self._database = database
        # Session / snapshot are created lazily; see '_get_session' /
        # '_get_snapshot'.
        self._session = None
        self._snapshot = None
        self._read_timestamp = read_timestamp
        self._min_read_timestamp = min_read_timestamp
        self._max_staleness = max_staleness
        self._exact_staleness = exact_staleness

    @classmethod
    def from_dict(cls, database, mapping):
        """Reconstruct an instance from a mapping.

        Binds to the session / transaction identified by ``mapping``
        (as produced by :meth:`to_dict`, typically in another process),
        rather than creating new ones.

        :type database: :class:`~google.cloud.spanner.database.Database`
        :param database: database to use

        :type mapping: mapping
        :param mapping: serialized state of the instance

        :rtype: :class:`BatchTransaction`
        :returns: the reconstituted instance
        """
        instance = cls(database)
        session = instance._session = database.session()
        session._session_id = mapping['session_id']
        txn = session.transaction()
        txn._transaction_id = mapping['transaction_id']
        return instance

    def to_dict(self):
        """Return state as a dictionary.

        Result can be used to serialize the instance and reconstitute
        it later using :meth:`from_dict`.

        :rtype: dict
        :returns: mapping with the session ID and transaction ID
        """
        session = self._get_session()
        return {
            'session_id': session._session_id,
            'transaction_id': session._transaction._transaction_id,
        }

    def _get_session(self):
        """Create session as needed.

        .. note::

           Caller is responsible for cleaning up the session after
           all partitions have been processed.
        """
        if self._session is None:
            session = self._session = self._database.session()
            session.create()
            txn = session.transaction()
            txn.begin()
        return self._session

    def _get_snapshot(self):
        """Create snapshot if needed."""
        if self._snapshot is None:
            # 'multi_use' is required so the same snapshot can serve
            # every partitioned read / query.
            self._snapshot = self._get_session().snapshot(
                read_timestamp=self._read_timestamp,
                min_read_timestamp=self._min_read_timestamp,
                max_staleness=self._max_staleness,
                exact_staleness=self._exact_staleness,
                multi_use=True)
        return self._snapshot

    def generate_read_batches(
            self, table, columns, keyset,
            index='', partition_size_bytes=None, max_partitions=None):
        """Start a partitioned batch read operation.

        Uses the ``PartitionRead`` API request to initiate the partitioned
        read.  Returns a list of batch information needed to perform the
        actual reads.

        :type table: str
        :param table: name of the table from which to fetch data

        :type columns: list of str
        :param columns: names of columns to be retrieved

        :type keyset: :class:`~google.cloud.spanner_v1.keyset.KeySet`
        :param keyset: keys / ranges identifying rows to be retrieved

        :type index: str
        :param index: (Optional) name of index to use, rather than the
                      table's primary key

        :type partition_size_bytes: int
        :param partition_size_bytes:
            (Optional) desired size for each partition generated.  The service
            uses this as a hint, the actual partition size may differ.

        :type max_partitions: int
        :param max_partitions:
            (Optional) desired maximum number of partitions generated.  The
            service uses this as a hint, the actual number of partitions may
            differ.

        :rtype: iterable of dict
        :returns:
            mappings of information used to perform actual partitioned reads
            via :meth:`process_read_batch`.
        """
        partitions = self._get_snapshot().partition_read(
            table=table, columns=columns, keyset=keyset, index=index,
            partition_size_bytes=partition_size_bytes,
            max_partitions=max_partitions)

        # Keyset is serialized so the whole batch mapping can be
        # marshalled across the wire to another process / host.
        read_info = {
            'table': table,
            'columns': columns,
            'keyset': keyset._to_dict(),
            'index': index,
        }
        for partition in partitions:
            # Copy so consumers mutating one batch cannot corrupt others.
            yield {'partition': partition, 'read': read_info.copy()}

    def process_read_batch(self, batch):
        """Process a single, partitioned read.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_read_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        """
        kwargs = batch['read']
        keyset_dict = kwargs.pop('keyset')
        kwargs['keyset'] = KeySet._from_dict(keyset_dict)
        return self._get_snapshot().read(
            partition=batch['partition'], **kwargs)

    def generate_query_batches(
            self, sql, params=None, param_types=None,
            partition_size_bytes=None, max_partitions=None):
        """Start a partitioned query operation.

        Uses the ``PartitionQuery`` API request to start a partitioned
        query operation.  Returns a list of batch information needed to
        perform the actual queries.

        :type sql: str
        :param sql: SQL query statement

        :type params: dict, {str -> column value}
        :param params: values for parameter replacement.  Keys must match
                       the names used in ``sql``.

        :type param_types: dict[str -> Union[dict, .types.Type]]
        :param param_types:
            (Optional) maps explicit types for one or more param values;
            required if parameters are passed.

        :type partition_size_bytes: int
        :param partition_size_bytes:
            (Optional) desired size for each partition generated.  The service
            uses this as a hint, the actual partition size may differ.

        :type max_partitions: int
        :param max_partitions:
            (Optional) desired maximum number of partitions generated.  The
            service uses this as a hint, the actual number of partitions may
            differ.

        :rtype: iterable of dict
        :returns:
            mappings of information used to perform actual partitioned
            queries via :meth:`process_query_batch`.
        """
        partitions = self._get_snapshot().partition_query(
            sql=sql, params=params, param_types=param_types,
            partition_size_bytes=partition_size_bytes,
            max_partitions=max_partitions)

        query_info = {'sql': sql}
        if params:
            query_info['params'] = params
            query_info['param_types'] = param_types

        for partition in partitions:
            # Copy (mirroring 'generate_read_batches') so consumers
            # mutating one batch cannot corrupt the others.
            yield {'partition': partition, 'query': query_info.copy()}

    def process_query_batch(self, batch):
        """Process a single, partitioned query.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_query_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        """
        return self._get_snapshot().execute_sql(
            partition=batch['partition'], **batch['query'])

    def process(self, batch):
        """Process a single, partitioned query or read.

        Dispatches to :meth:`process_query_batch` or
        :meth:`process_read_batch` based on the batch's contents.

        :type batch: mapping
        :param batch:
            one of the mappings returned from an earlier call to
            :meth:`generate_query_batches` or :meth:`generate_read_batches`.

        :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
        :returns: a result set instance which can be used to consume rows.
        :raises ValueError: if batch does not contain either 'read' or 'query'
        """
        if 'query' in batch:
            return self.process_query_batch(batch)
        if 'read' in batch:
            return self.process_read_batch(batch)
        raise ValueError("Invalid batch")

    def close(self):
        """Clean up underlying session."""
        if self._session is not None:
            self._session.delete()
673+
674+
409675
def _check_ddl_statements(value):
410676
"""Validate DDL Statements used to define database schema.
411677

‎spanner/google/cloud/spanner_v1/keyset.py

Copy file name to clipboardExpand all lines: spanner/google/cloud/spanner_v1/keyset.py
+67Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ def _to_pb(self):
8585

8686
return KeyRangePB(**kwargs)
8787

88+
def _to_dict(self):
    """Return the key range's state as a dict.

    Only endpoints which are set (truthy) appear in the result.

    :rtype: dict
    :returns: state of this instance.
    """
    state = {}

    for name in ('start_open', 'start_closed', 'end_open', 'end_closed'):
        value = getattr(self, name)
        if value:
            state[name] = value

    return state
109+
110+
def __eq__(self, other):
    """Compare two key ranges via their serialized state."""
    if isinstance(other, self.__class__):
        return self._to_dict() == other._to_dict()
    return NotImplemented
115+
116+
88117

89118
class KeySet(object):
90119
"""Identify table rows via keys / ranges.
@@ -122,3 +151,41 @@ def _to_pb(self):
122151
kwargs['ranges'] = [krange._to_pb() for krange in self.ranges]
123152

124153
return KeySetPB(**kwargs)
154+
155+
def _to_dict(self):
    """Return the keyset's state as a dict.

    The result can be used to serialize the instance and reconstitute
    it later using :meth:`_from_dict`.

    :rtype: dict
    :returns: state of this instance.
    """
    if self.all_:
        return {'all': True}

    serialized_ranges = [krange._to_dict() for krange in self.ranges]
    return {'keys': self.keys, 'ranges': serialized_ranges}
171+
172+
def __eq__(self, other):
    """Compare two keysets via their serialized state."""
    if isinstance(other, self.__class__):
        return self._to_dict() == other._to_dict()
    return NotImplemented
177+
178+
@classmethod
def _from_dict(cls, mapping):
    """Create an instance from the corresponding state mapping.

    :type mapping: dict
    :param mapping: the instance state, as produced by :meth:`_to_dict`.

    :rtype: :class:`KeySet`
    :returns: keyset reconstituted from ``mapping``.
    """
    if mapping.get('all'):
        return cls(all_=True)

    ranges = []
    for range_state in mapping.get('ranges', ()):
        ranges.append(KeyRange(**range_state))

    return cls(keys=mapping.get('keys', ()), ranges=ranges)

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.